Add TraceManager around frontend opt

From: @irmo
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-30 10:05:06 +08:00 committed by Gitee
commit 6c287c28ca
22 changed files with 173 additions and 164 deletions

View File

@ -517,12 +517,11 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (iter != python_paras->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (param_value != nullptr) {
(*python_paras)[param_value] = new_parameter;
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
graph_inputs->push_back(new_parameter);
@ -627,9 +626,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
std::vector<AnfNodePtr> cnode_inputs;
GetCNodeInfo(cnode, &cnode_inputs);
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();
return new_cnode;
}
@ -806,9 +804,8 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
// handle inputs of cnode except primitive
CreateCNodeInputs(cnode, graph, &cnode_inputs);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();
// if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) {
@ -865,12 +862,11 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
if (iter != python_paras->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (param_value != nullptr) {
(*python_paras)[param_value] = new_parameter;
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();

View File

@ -659,9 +659,11 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr ptr_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(ptr_graph);
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
FuncGraphPtr df_builder = nullptr;
{
TraceGuard g(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
df_builder = std::make_shared<FuncGraph>();
}
auto nparam = ptr_graph->parameters().size();
std::ostringstream ss;
@ -680,9 +682,11 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
inputs.push_back(param_graph);
auto jf = df_builder->NewCNode(inputs);
// df is checked in GetGrad
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
auto df = GetGrad(jf, weights, ptr_graph->parameters());
TraceManager::EndTrace();
FuncGraphPtr df = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
df = GetGrad(jf, weights, ptr_graph->parameters());
}
df_builder->set_output(NewValueNode(df));
return df_builder;

View File

@ -41,8 +41,10 @@ FuncGraphSet DFunctor::scope_;
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>();
{
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>();
}
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
@ -50,17 +52,17 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
// To keep switch_layer's inputs from being inlined
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
k_graph_->set_stage(primal_graph->stage());
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
tape_ = std::make_shared<FuncGraph>();
{
TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
tape_ = std::make_shared<FuncGraph>();
}
tape_->set_stage(primal_graph->stage());
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
TraceManager::EndTrace();
dout_ = tape_->add_parameter();
}
@ -232,9 +234,8 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr;
if (IsValueNode<Primitive>(node)) {
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
k = MapToK(node);
TraceManager::EndTrace();
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint;
} else {
@ -254,9 +255,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
inputs.push_back(k);
param_adjoints.push_back(node_adjoint);
}
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
auto k_app = k_graph_->NewCNode(inputs);
TraceManager::EndTrace();
CNodePtr k_app = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
k_app = k_graph_->NewCNode(inputs);
}
ReplaceEquivdout(k_app, cnode_morph);
cnode_morph->set_forward(nullptr, "");
for (size_t i = 0; i < param_adjoints.size(); ++i) {
@ -624,9 +627,8 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
}
if (primal->isa<Parameter>()) {
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
auto ret = k_graph_->add_parameter();
TraceManager::EndTrace();
return ret;
}
@ -812,9 +814,8 @@ void DFunctor::EliminatePrimalGraph() {
}
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
auto construct_wrapper = cnode->func_graph();
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceGradFpropApp>(cnode->debug_info()));
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
TraceManager::EndTrace();
manager->Replace(cnode, getitem0);
}
}

View File

@ -173,11 +173,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
AnfNodePtr bout = BuildOutput(cloned_bprop_fg);
cloned_bprop_fg->set_output(bout);
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(debug_info));
auto outer = std::make_shared<FuncGraph>();
(void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
outer->set_output(NewValueNode(kNone));
TraceManager::EndTrace();
FuncGraphPtr outer = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradFprop>(debug_info));
outer = std::make_shared<FuncGraph>();
(void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
outer->set_output(NewValueNode(kNone));
}
auto mng = Manage({cloned_bprop_fg, outer}, false);
@ -199,13 +201,12 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
(void)mng->Replace(out_param, out_value);
TraceManager::DebugTrace(std::make_shared<TraceGradSens>(out_param->debug_info()));
TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
auto new_dout = cloned_bprop_fg->add_parameter();
(void)mng->Replace(dout, new_dout);
// We remove all parameters except new_dout.
std::vector<AnfNodePtr> newBpropParams = {new_dout};
cloned_bprop_fg->set_parameters(newBpropParams);
TraceManager::EndTrace();
outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
return BasicClone(outer);

View File

@ -181,9 +181,8 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
auto p = bprop_fg->parameters()[i];
MS_EXCEPTION_IF_NULL(p);
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(p->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
auto transf_p = outer->add_parameter();
TraceManager::EndTrace();
(void)mng->Replace(p, transf_p);
transf_args->push_back(transf_p);

View File

@ -42,9 +42,8 @@ class ExpandJPrim : public AnfVisitor {
x_ = nullptr;
AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node);
if (x_ != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceExpandJ>(node->debug_info()));
TraceGuard guard(std::make_shared<TraceExpandJ>(node->debug_info()));
auto j_node = internal::ExpandJ(x_, optimizer->resource());
TraceManager::EndTrace();
return j_node;
}
return nullptr;

View File

@ -49,9 +49,8 @@ class PartialEliminater : public AnfVisitor {
std::vector<AnfNodePtr> args{};
(void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
TraceManager::DebugTrace(std::make_shared<TracePartialTransform>(node->debug_info()));
TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
auto new_node = node->func_graph()->NewCNode(args);
TraceManager::EndTrace();
return new_node;
}

View File

@ -133,6 +133,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
// apply transform on this node
bool change = false;
if (is_match) {
TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
auto ret = (*transform)(optimizer, node);
if (ret != nullptr && ret != node) {
change = true;

View File

@ -172,9 +172,8 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
}
auto &cloned_nodes = *cloner->cloned_node();
for (auto &fv : fg->paramter_obj_nodes()) {
TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
auto param = base_graph->add_parameter();
TraceManager::EndTrace();
auto &node_users = res->manager()->node_users()[fv];
for (auto &n : node_users) {
// If the user is not in this graph, no need to change.

View File

@ -76,9 +76,8 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
// If have more than one predecessor blocks then build a phi node.
auto debug_info = std::make_shared<NodeDebugInfo>();
debug_info->set_name(var);
TraceManager::DebugTrace(std::make_shared<TracePhi>(debug_info));
TraceGuard guard(std::make_shared<TracePhi>(debug_info));
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
TraceManager::EndTrace();
MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
func_graph()->add_parameter(phi_param);
phi_nodes_[phi_param] = var;
@ -264,16 +263,14 @@ void FunctionBlock::Mature() {
// Force the conditIon node to bool using bool operation
CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
TraceManager::DebugTrace(std::make_shared<TraceForceBool>(cond->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceForceBool>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
TraceManager::EndTrace();
return op_apply_node;
}
CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
TraceManager::EndTrace();
return op_apply_node;
}

View File

@ -238,10 +238,9 @@ void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py
continue;
}
}
TraceManager::DebugTrace(GetLocation(args[i]));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block->func_graph());
MS_EXCEPTION_IF_NULL(para_node);
TraceManager::EndTrace();
para_node->set_name(arg_name);
para_node->debug_info()->set_name(arg_name);
block->func_graph()->add_parameter(para_node);
@ -346,9 +345,8 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::ob
MS_LOG(DEBUG) << "The nodes count is " << count;
for (size_t i = 0; i < count; i++) {
auto node = py::cast<py::list>(nodes)[i];
TraceManager::DebugTrace(GetLocation(node));
TraceGuard guard(GetLocation(node));
fn_block = ParseStatement(fn_block, node);
TraceManager::EndTrace();
// insert appropriate depended items for the function block if it has a return node
if (fn_block->func_graph()->get_return() != nullptr) {
fn_block->InsertDependItemsBeforeReturn();
@ -372,9 +370,8 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
if (stmt_method_map_.count(node_name)) {
TraceManager::DebugTrace(GetLocation(node));
TraceGuard trace_guard(GetLocation(node));
auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
TraceManager::EndTrace();
return stmt_block;
} else {
errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
@ -406,9 +403,8 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
if (expr_method_map_.count(node_name)) {
TraceManager::DebugTrace(GetLocation(node));
TraceGuard trace_guard(GetLocation(node));
auto expr_node = (this->*expr_method_map_[node_name])(block, node);
TraceManager::EndTrace();
return expr_node;
} else {
errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
@ -756,9 +752,8 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// process the node attr
auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
MS_LOG(DEBUG) << "Attr = " << attr_str;
TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
AnfNodePtr attr_node = NewValueNode(attr_str);
TraceManager::EndTrace();
// create the apply node
return block->func_graph()->NewCNode({op_node, value_node, attr_node});
@ -799,12 +794,16 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
rest.append(value_list[i]);
}
MS_EXCEPTION_IF_NULL(block);
TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
FunctionBlockPtr b1, b2;
@ -874,9 +873,8 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
py::list args = ast_->GetArgs(node);
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg = py::cast<std::string>(args[i].attr("arg"));
TraceManager::DebugTrace(GetLocation(args[i]));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(func_block->func_graph());
TraceManager::EndTrace();
para_node->debug_info()->set_name(arg);
func_block->func_graph()->add_parameter(para_node);
func_block->WriteVariable(arg, para_node);
@ -1065,19 +1063,24 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
MS_EXCEPTION_IF_NULL(block);
CNodePtr bool_node = block->ForceToBoolNode(condition_node);
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
if (MsContext::GetInstance()->backend_policy() != "ge") {
// for backends excludes 'ge', it can handle multi graph call, use this flag to
@ -1112,17 +1115,21 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
MS_LOG(DEBUG) << "Process ast While";
MS_EXCEPTION_IF_NULL(block);
MS_LOG(INFO) << "Parse while statement";
TraceManager::DebugTrace(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
FunctionBlockPtr header_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
FunctionBlockPtr body_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr header_block = nullptr;
FunctionBlockPtr body_block = nullptr;
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
header_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
body_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
body_block->AddPrevBlock(header_block);
after_block->AddPrevBlock(header_block);
@ -1169,9 +1176,8 @@ CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const Functio
}
FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
TraceManager::DebugTrace(trace_info);
TraceGuard trace_guard(trace_info);
FunctionBlockPtr body_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
return body_block;
}
@ -1195,19 +1201,24 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
FunctionBlockPtr true_end = ParseForIter(true_block, node);
true_end->Jump(after_block, nullptr);
@ -1263,10 +1274,12 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
iter2_app->debug_info()->set_trace_info(it_info);
iter_apply->debug_info()->set_trace_info(it_info);
TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
MS_EXCEPTION_IF_NULL(after_block);
TraceManager::EndTrace();
after_block->AddPrevBlock(header_block);
block->Jump(header_block, iter_apply);
@ -1350,10 +1363,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
loop_var->debug_info()->set_trace_info(it_info);
len_iter->debug_info()->set_trace_info(it_info);
TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
MS_EXCEPTION_IF_NULL(after_block);
TraceManager::EndTrace();
after_block->AddPrevBlock(header_block);
block->Jump(header_block, NewValueNode(static_cast<int64_t>(0)));
@ -1389,13 +1404,16 @@ AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &n
AnfNodePtr condition_node = ParseExprNode(block, test_node);
CNodePtr bool_node = block->ForceToBoolNode(condition_node);
TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
@ -1581,9 +1599,8 @@ FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::obj
Loop &loop = loops_.top();
if (loop.end == nullptr) {
// Create end_block if it is not existed.
TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
loop.end = MakeFunctionBlock(*this);
TraceManager::EndTrace();
}
// Jump to the end_block.
block->Jump(loop.end, nullptr);

View File

@ -221,7 +221,7 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
const AnfNodePtr &node) {
ScopeGuard scope_guard(node->scope());
AnfNodePtr resolved_node = nullptr;
TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
if (!success) {
MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
@ -236,8 +236,6 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
&resolved_node);
}
TraceManager::EndTrace();
return resolved_node;
}
} // namespace

View File

@ -47,6 +47,7 @@
#include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/common.h"
#include "ps/util.h"

View File

@ -208,9 +208,8 @@ FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const Ab
if (iter == func_graph_cache_.end()) {
auto fg = func_graph();
MS_EXCEPTION_IF_NULL(fg);
TraceManager::DebugTrace(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
TraceGuard guard(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
TraceManager::EndTrace();
func_graph_cache_[args_spec_list] = generated_graph;
MS_EXCEPTION_IF_NULL(engine);
engine->func_graph_manager()->AddFuncGraph(generated_graph);
@ -237,9 +236,8 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
MS_EXCEPTION_IF_NULL(meta_func_graph_);
FuncGraphPtr generated_func_graph = nullptr;
if (this->bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
TraceManager::EndTrace();
} else {
generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
}

View File

@ -88,10 +88,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodePtr new_cnode = nullptr;
if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
TraceManager::EndTrace();
} else {
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
@ -936,9 +935,8 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
}
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
TraceManager::EndTrace();
} else {
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
}
@ -962,9 +960,8 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
}
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
TraceManager::EndTrace();
} else {
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
}

View File

@ -115,10 +115,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}
TraceManager::DebugTrace(
std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
auto fg = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
// Merge CNodes into a AnfGraph that represents a linear instruction segment
@ -152,9 +150,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
}
TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
TraceGuard tg(std::make_shared<TraceSegmentTransform>(n->debug_info()));
eqv[n] = fg->NewCNode(args);
TraceManager::EndTrace();
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}

View File

@ -66,7 +66,7 @@ void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
TraceGuard trace_guard(node->debug_info(), relation_);
auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
auto old_param = node->cast<ParameterPtr>();
new_param->set_abstract(old_param->abstract());
@ -78,13 +78,12 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_param->set_scope(scope);
repl_node_[node] = new_param;
TraceManager::EndTrace();
}
void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
TraceGuard trace_guard(node->debug_info(), relation_);
CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
auto old_node = node->cast<CNodePtr>();
new_node->set_abstract(old_node->abstract());
@ -95,32 +94,29 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_kernel_info(old_node->kernel_info_ptr());
repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node);
TraceManager::EndTrace();
}
void Cloner::CloneValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TraceManager::DebugTrace(node->debug_info(), relation_);
TraceGuard trace_guard(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(GetValueNode(node));
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
repl_node_[node] = new_const;
TraceManager::EndTrace();
}
void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
TraceGuard trace_guard(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(target);
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
repl_node_[node] = new_const;
TraceManager::EndTrace();
}
void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
@ -219,7 +215,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNode
void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
TraceGuard trace_guard(func_graph->debug_info(), target_relation_);
*target_func_graph = std::make_shared<FuncGraph>();
(*target_func_graph)->set_attrs(func_graph->attrs());
(*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
@ -231,7 +227,6 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
(*target_func_graph)->set_is_generate(func_graph->is_generated());
(*target_func_graph)->set_stub(func_graph->stub());
(*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
TraceManager::EndTrace();
}
void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
@ -273,9 +268,8 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
}
ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(func_graph);
TraceManager::EndTrace();
CloneParameter(param, node);
if (is_add) {
func_graph->add_parameter(param);
@ -633,16 +627,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &r
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
MS_EXCEPTION_IF_NULL(func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), relation);
TraceGuard guard(func_graph->debug_info(), relation);
auto new_func_graph = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
auto &parameters = func_graph->parameters();
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
MS_EXCEPTION_IF_NULL(param);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceCopy>(param->debug_info()));
(void)new_func_graph->add_parameter();
TraceManager::EndTrace();
});
Cloner cloner = Cloner();

View File

@ -89,7 +89,7 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
int pos_args_input_count) {
// if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
if (specialized_graph->has_vararg()) {
TraceManager::DebugTrace(
TraceGuard trace_guard(
std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
std::vector<AnfNodePtr> var_param_tuple_nodes;
var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
@ -112,7 +112,6 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
}
auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes);
(void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
TraceManager::EndTrace();
} else if (variable_args_count > 0) {
MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount()
<< " positional arguments, but " << pos_args_input_count << " were given.";
@ -181,7 +180,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
if (has_kwarg()) {
MS_EXCEPTION_IF_NULL(specialized_graph);
TraceManager::DebugTrace(
TraceGuard guard(
std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
@ -189,7 +188,6 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
MS_EXCEPTION_IF_NULL(repl_nodes);
(void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
TraceManager::EndTrace();
}
}

View File

@ -80,10 +80,14 @@ class TraceManager {
class TraceGuard {
public:
explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location);
}
explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
explicit TraceGuard(const TraceInfoPtr &trace_info) { TraceManager::DebugTrace(trace_info); }
TraceGuard(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
TraceManager::DebugTrace(debug_info, trace_info);
}
~TraceGuard() { TraceManager::EndTrace(); }
};

View File

@ -111,7 +111,7 @@ std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, So
return oss.str();
}
std::string DumpSourceLines(const AnfNodePtr node) { return DumpSourceLines(node.get()); }
std::string DumpSourceLines(const AnfNodePtr &node) { return DumpSourceLines(node.get()); }
std::string DumpSourceLines(AnfNode *node) {
if (node == nullptr) {
@ -120,6 +120,9 @@ std::string DumpSourceLines(AnfNode *node) {
}
auto info_vec = GetSourceCodeDebugInfoVec(node->debug_info());
std::ostringstream oss;
if (!info_vec.empty()) {
oss << "\n";
}
for (auto info : info_vec) {
MS_EXCEPTION_IF_NULL(info);
auto loc = info->location();
@ -134,7 +137,7 @@ std::string DumpSourceLines(AnfNode *node) {
return oss.str();
}
std::vector<std::string> GetSourceLineList(const AnfNodePtr node) {
std::vector<std::string> GetSourceLineList(const AnfNodePtr &node) {
std::vector<std::string> result;
if (node == nullptr) {
MS_LOG(WARNING) << "Node is null";
@ -155,7 +158,7 @@ std::vector<std::string> GetSourceLineList(const AnfNodePtr node) {
return result;
}
std::vector<LocationPtr> GetSourceLocationList(const AnfNodePtr node) {
std::vector<LocationPtr> GetSourceLocationList(const AnfNodePtr &node) {
std::vector<LocationPtr> result;
if (node == nullptr) {
MS_LOG(WARNING) << "Node is null";
@ -171,7 +174,7 @@ std::vector<LocationPtr> GetSourceLocationList(const AnfNodePtr node) {
return result;
}
std::string GetDebugTraceInfo(const AnfNodePtr node, bool is_debug) {
std::string GetDebugTraceInfo(const AnfNodePtr &node, bool is_debug) {
if (node == nullptr) {
MS_LOG(WARNING) << "Node is null";
return "";

View File

@ -34,14 +34,14 @@ std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLi
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine);
// Generate the call stack of python source code to a string
std::string DumpSourceLines(const AnfNodePtr node);
std::string DumpSourceLines(const AnfNodePtr &node);
std::string DumpSourceLines(AnfNode *node);
// Generate the call stack of python source code to a vector
std::vector<std::string> GetSourceLineList(const AnfNodePtr node);
std::vector<std::string> GetSourceLineList(const AnfNodePtr &node);
// Get the locations of the call stack of python source code
std::vector<LocationPtr> GetSourceLocationList(const AnfNodePtr node);
std::vector<LocationPtr> GetSourceLocationList(const AnfNodePtr &node);
// Generate the call stack of python source code with relevant trace info
std::string GetDebugTraceInfo(const AnfNodePtr node, bool is_debug = false);
std::string GetDebugTraceInfo(const AnfNodePtr &node, bool is_debug = false);
} // namespace trace
} // namespace mindspore

View File

@ -416,12 +416,20 @@ class TraceCombileLikeGraphs : public TraceInfo {
class TraceSegmentTransform : public TraceInfo {
public:
explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
MS_DECLARE_PARENT(TraceSegmentTransform, TraceInfo);
~TraceSegmentTransform() override = default;
TraceInfoPtr clone() override {
return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
}
};
class TraceOpt : public TraceInfo {
public:
explicit TraceOpt(const DebugInfoPtr &info) : TraceInfo(info, "opt", "") {}
MS_DECLARE_PARENT(TraceOpt, TraceInfo);
~TraceOpt() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceOpt>(*shared_from_base<TraceOpt>()); }
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_