diff --git a/mindspore/ccsrc/pipeline/jit/pi/common.cc b/mindspore/ccsrc/pipeline/jit/pi/common.cc index f4aac3aa5f9..fb3fe2d0306 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/common.cc +++ b/mindspore/ccsrc/pipeline/jit/pi/common.cc @@ -495,6 +495,18 @@ static void MarkBreak(Graph *g) { } } +std::vector GetAllArgs(JitCompileResults *jcr) { + auto all_args = PackArgs(jcr->origin_frame_); + auto args = py::cast(all_args[0]); + if (all_args[1].ptr() != nullptr) { + PyList_Append(args.ptr(), all_args[1].ptr()); // args + vargs + } + if (all_args[2].ptr() != nullptr) { + PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs + } + return args.cast>(); +} + // preprocess before compile, split bytecode to sub-function // return whether the code should be modified static bool GraphCapture(JitCompileResults *jcr) { @@ -503,13 +515,12 @@ static bool GraphCapture(JitCompileResults *jcr) { GraphJitConfig &conf = *jcr->conf; auto g = GraphBuilder::Creator(jcr->origin_frame_, conf.GetBoolConfig(GraphJitConfig::kTraceFlag)); - auto all_args = PackArgs(jcr->origin_frame_); - auto args = py::cast(all_args[0]); - if (all_args[2].ptr() != nullptr) { - PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs - } - (void)g->TraceRun(args.cast>()); + if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) { + auto mg = std::dynamic_pointer_cast(g); + mg->FGAddInputs(GetAllArgs(jcr)); + } + (void)g->TraceRun(); if (g->StackSize() > 0) { auto block = g->PeekStack(0); @@ -939,9 +950,6 @@ std::vector PackArgs(const PyFrameObject *frame) { args[argi] = py::reinterpret_borrow(PyCell_GET(cell)); } } - if (vargs.ptr() != nullptr) { - PyList_Append(args.ptr(), vargs.ptr()); - } return {args, vargs, kwvargs}; } @@ -1151,6 +1159,9 @@ static py::object CallCompiledResults(PyThreadState *tstate, PyFrameObject *f, c ValidateCompiledResults(c); std::vector packed_args = PackArgs(f); + if (packed_args[1].ptr() != nullptr) { + PyList_Append(packed_args[0].ptr(), packed_args[1].ptr()); + } py::object args = py::reinterpret_steal(PyList_AsTuple(packed_args[0].ptr())); py::object kwvargs = packed_args[2]; @@ -1410,9 +1421,13 @@ PyObject *EvalFrame(PyThreadState *tstate, PyFrameObject *f, int exc) { } py::object res; try { - common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0"); + if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) { + common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0"); + } res = CodeHook(tstate, c, f); - common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2"); + if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) { + common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2"); + } } catch (py::error_already_set &e) { MS_LOG(ERROR) << "execute failed with " << e.what() << " at " << std::string(py::str(reinterpret_cast(f->f_code))); diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_analyzer.cc b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_analyzer.cc index 2c0b3ece8d7..d733bfcd1ef 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_analyzer.cc +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_analyzer.cc @@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() { // UD analyze: alive nodes analysis std::vector aliveLocals = GetAliveLocals(graph_); if (!aliveLocals.empty()) { - bool stop_analyze = false; - while (!stop_analyze) { - stop_analyze = AnalyzeAliveLocals(aliveLocals); - if (stop_analyze) { + bool isStopAnalyze = false; + while (!isStopAnalyze) { + isStopAnalyze = AnalyzeAliveLocals(aliveLocals); + if (isStopAnalyze) { break; } aliveLocals = GetAliveLocals(graph_); diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.cc b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.cc index a5b1063b1fc..8f1efa06f96 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.cc +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.cc @@ -2034,7 +2034,7 @@ bool GraphBuilder::ReplaceCall(CallNode *call_node, const py::object &old_func) } namespace { -std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subgraph) { +std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) { auto func_str = py::cast(py::str(func)); std::vector vec; std::istringstream iss(func_str); @@ -2053,15 +2053,16 @@ std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subg StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func, const GraphBuilderPtr &subgraph) { + auto sg = std::dynamic_pointer_cast(subgraph); InlineReason stat = InlineReason::kInline; bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION; if (is_make_func) { // inline MAKE_FUNCTION, need eliminate cell and free variable if the function is not dead local. - bool has_cell = PyTuple_GET_SIZE(subgraph->GetGraph()->GetCodeObj()->co_cellvars) != 0; + bool has_cell = PyTuple_GET_SIZE(sg->GetGraph()->GetCodeObj()->co_cellvars) != 0; stat = has_cell ? InlineReason::kInlinePolicyDisabled : stat; } - auto code = subgraph->GetGraph()->GetGuard(); + auto code = sg->GetGraph()->GetGuard(); MS_EXCEPTION_IF_NULL(code); code->GetGuard()->Backup(); @@ -2071,18 +2072,18 @@ StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, } MS_LOG(INFO) << "new subgraph->TraceRun:" << py::str(func); - auto reason = subgraph->TraceRun(args); + sg->FGAddInputs(args); + auto reason = sg->TraceRun(); MS_LOG(INFO) << "new subgraph->TraceRun end:" << py::str(func); - call_node->SetSubGraph(subgraph->GetGraph()); - auto sg = std::dynamic_pointer_cast(subgraph); - auto sub_ret = subgraph->GetGraph()->GetRetVal(); + call_node->SetSubGraph(sg->GetGraph()); + auto sub_ret = sg->GetGraph()->GetRetVal(); if (sub_ret != nullptr) { if (sub_ret->GetVobj()->GetPyObject().ptr() == nullptr || CheckConstPyObject(sub_ret->GetVobj()->GetPyObject().ptr())) { call_node->SetVobj(sub_ret->GetVobj()); } else { - sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, subgraph)); + sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, sg)); sg->FGAddOutput(); if (sg->FGBuilder()->graph() == nullptr) { MS_LOG(ERROR) << "subgraph trace null"; @@ -2140,7 +2141,7 @@ StopTraceReason GraphBuilder::BuildSubGraph(CallNode *call_node, int depth, cons code->GetGuard()->Backup(); MS_LOG(INFO) << "old subgraph->TraceRun"; - subgraph->TraceRun(call_node->GetArgs()); + subgraph->TraceRun(); call_node->SetSubGraph(subgraph->GetGraph()); if (subgraph->GetGraph()->GetRetVal() != nullptr) { @@ -2533,15 +2534,13 @@ bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *c static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node); -StopTraceReason MindGraphBuilder::TraceRun(const std::vector &args) { +void MindGraphBuilder::FGAddInputs(const std::vector &args) { // Add function graph inputs. for (size_t i = 0; i < args.size(); ++i) { MS_LOG(INFO) << "try add input: " << py::str(args[i]); FGBuilder()->AddInput(args[i]); MS_LOG(INFO) << "add input suc"; } - auto res = GraphBuilder::TraceRun(args); - return res; } void MindGraphBuilder::FGAddOutput() { @@ -2553,7 +2552,6 @@ void MindGraphBuilder::FGAddOutput() { MS_LOG(INFO) << "add output succuss"; } else { MS_LOG(ERROR) << "add output fail"; - // TODO(xiaruijie) } } } @@ -2642,8 +2640,7 @@ py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReaso } return FGAddNode(call_node, callable_info, args, stop_reason); } - if (FGBuilder()->CanConstantFoldFunc(callable_info) || - (CheckCell(callable_info) && callable->GetType() == AObject::kTypeType)) { + if (FGBuilder()->CanConstantFoldFunc(callable_info)) { MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info); JustCallAndSetRes(call_node); *stop_reason = StopTraceReason::kNonStopTrace; @@ -3333,8 +3330,7 @@ static void EliminateCellAccess(Graph *g) { } } -StopTraceReason GraphBuilder::TraceRun(const std::vector &args) { - args_ = args; +StopTraceReason GraphBuilder::TraceRun() { current_block_ = graph_->GetCFG()->GetFirstBB(); cur_bci_ = 0; const auto &instrs = graph_->GetCFG()->instr_pool(); @@ -3387,7 +3383,7 @@ AObject *InferFuncResult(const py::object &callable, const py::object &args, con if (g == nullptr) { return nullptr; } - g->TraceRun(py::cast(args).cast>()); + g->TraceRun(); if (clear_guard) { Graph *graph = g->GetGraph(); auto jcr = getJitCompileResults(reinterpret_cast(graph->GetCodeObj())); diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.h b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.h index 7224c980f84..35f4fa92cf2 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.h +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/graph_build.h @@ -66,7 +66,7 @@ class GraphBuilder { : std::make_shared(r, p, co, globals); } - virtual StopTraceReason TraceRun(const std::vector &args); + StopTraceReason TraceRun(); virtual bool trace_flag() { return false; } void CollectInlineInfo(CallNode *node, int depth); @@ -87,7 +87,6 @@ class GraphBuilder { TryBlock &PopStack(); protected: - std::vector args_; // inputs GraphBuilder *root_; GraphBuilder *parent_; Graph *graph_; @@ -314,7 +313,7 @@ class MindGraphBuilder : public GraphBuilder { } bool trace_flag() { return true; } mindspore::FuncGraphBuilderPtr FGBuilder() const { return fg_builder_; } - StopTraceReason TraceRun(const std::vector &args); + void FGAddInputs(const std::vector &args); py::object FGAddNode(CallNode *call_node, const py::object &callable_info, const std::vector &args, StopTraceReason *stop_reason); void FGAddOutput(); diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.cc b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.cc index ee2c5558cea..f0cee2cfe91 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.cc +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.cc @@ -639,7 +639,7 @@ static bool CheckJitFunc(const py::object &o) { return size > except_size && !strncmp(file + (size - except_size), except_file, except_size); } -bool CheckCell(const py::object &callable_info) { +static bool CheckCell(const py::object &callable_info) { PyTypeObject *cell_type = PyType_Check(callable_info.ptr()) ? reinterpret_cast(callable_info.ptr()) : Py_TYPE(callable_info.ptr()); if (!IsCellType(cell_type)) { diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.h b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.h index a33f9b83dd9..b684efb43ec 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.h +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_capture/special_func_infer.h @@ -42,7 +42,6 @@ const std::vector> &GetFuncWhiteListFuzzyMatch const std::string GetMindsporeNamePrimitive(); bool InferListAppend(CallNode *call_node); - bool CheckCell(const py::object &callable_info); } // namespace pijit } // namespace mindspore