clean code
This commit is contained in:
parent
ccbaf161ff
commit
affe33e8ab
|
@ -495,6 +495,18 @@ static void MarkBreak(Graph *g) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<py::object> GetAllArgs(JitCompileResults *jcr) {
|
||||
auto all_args = PackArgs(jcr->origin_frame_);
|
||||
auto args = py::cast<py::list>(all_args[0]);
|
||||
if (all_args[1].ptr() != nullptr) {
|
||||
PyList_Append(args.ptr(), all_args[1].ptr()); // args + vargs
|
||||
}
|
||||
if (all_args[2].ptr() != nullptr) {
|
||||
PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs
|
||||
}
|
||||
return args.cast<std::vector<py::object>>();
|
||||
}
|
||||
|
||||
// preprocess before compile, split bytecode to sub-function
|
||||
// return whether the code should be modified
|
||||
static bool GraphCapture(JitCompileResults *jcr) {
|
||||
|
@ -503,13 +515,12 @@ static bool GraphCapture(JitCompileResults *jcr) {
|
|||
GraphJitConfig &conf = *jcr->conf;
|
||||
|
||||
auto g = GraphBuilder::Creator(jcr->origin_frame_, conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
|
||||
auto all_args = PackArgs(jcr->origin_frame_);
|
||||
auto args = py::cast<py::list>(all_args[0]);
|
||||
if (all_args[2].ptr() != nullptr) {
|
||||
PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs
|
||||
}
|
||||
|
||||
(void)g->TraceRun(args.cast<std::vector<py::object>>());
|
||||
if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
|
||||
auto mg = std::dynamic_pointer_cast<MindGraphBuilder>(g);
|
||||
mg->FGAddInputs(GetAllArgs(jcr));
|
||||
}
|
||||
(void)g->TraceRun();
|
||||
|
||||
if (g->StackSize() > 0) {
|
||||
auto block = g->PeekStack(0);
|
||||
|
@ -939,9 +950,6 @@ std::vector<py::object> PackArgs(const PyFrameObject *frame) {
|
|||
args[argi] = py::reinterpret_borrow<py::object>(PyCell_GET(cell));
|
||||
}
|
||||
}
|
||||
if (vargs.ptr() != nullptr) {
|
||||
PyList_Append(args.ptr(), vargs.ptr());
|
||||
}
|
||||
return {args, vargs, kwvargs};
|
||||
}
|
||||
|
||||
|
@ -1151,6 +1159,9 @@ static py::object CallCompiledResults(PyThreadState *tstate, PyFrameObject *f, c
|
|||
ValidateCompiledResults(c);
|
||||
|
||||
std::vector<py::object> packed_args = PackArgs(f);
|
||||
if (packed_args[1].ptr() != nullptr) {
|
||||
PyList_Append(packed_args[0].ptr(), packed_args[1].ptr());
|
||||
}
|
||||
|
||||
py::object args = py::reinterpret_steal<py::object>(PyList_AsTuple(packed_args[0].ptr()));
|
||||
py::object kwvargs = packed_args[2];
|
||||
|
@ -1410,9 +1421,13 @@ PyObject *EvalFrame(PyThreadState *tstate, PyFrameObject *f, int exc) {
|
|||
}
|
||||
py::object res;
|
||||
try {
|
||||
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
|
||||
if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
|
||||
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
|
||||
}
|
||||
res = CodeHook(tstate, c, f);
|
||||
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2");
|
||||
if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
|
||||
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2");
|
||||
}
|
||||
} catch (py::error_already_set &e) {
|
||||
MS_LOG(ERROR) << "execute failed with " << e.what() << " at "
|
||||
<< std::string(py::str(reinterpret_cast<PyObject *>(f->f_code)));
|
||||
|
|
|
@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() {
|
|||
// UD analyze: alive nodes analysis
|
||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||
if (!aliveLocals.empty()) {
|
||||
bool stop_analyze = false;
|
||||
while (!stop_analyze) {
|
||||
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (stop_analyze) {
|
||||
bool isStopAnalyze = false;
|
||||
while (!isStopAnalyze) {
|
||||
isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (isStopAnalyze) {
|
||||
break;
|
||||
}
|
||||
aliveLocals = GetAliveLocals(graph_);
|
||||
|
|
|
@ -2034,7 +2034,7 @@ bool GraphBuilder::ReplaceCall(CallNode *call_node, const py::object &old_func)
|
|||
}
|
||||
|
||||
namespace {
|
||||
std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subgraph) {
|
||||
std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) {
|
||||
auto func_str = py::cast<std::string>(py::str(func));
|
||||
std::vector<std::string> vec;
|
||||
std::istringstream iss(func_str);
|
||||
|
@ -2053,15 +2053,16 @@ std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subg
|
|||
|
||||
StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
|
||||
const GraphBuilderPtr &subgraph) {
|
||||
auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
|
||||
InlineReason stat = InlineReason::kInline;
|
||||
bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
|
||||
if (is_make_func) {
|
||||
// inline MAKE_FUNCTION, need eliminate cell and free variable if the function is not dead local.
|
||||
bool has_cell = PyTuple_GET_SIZE(subgraph->GetGraph()->GetCodeObj()->co_cellvars) != 0;
|
||||
bool has_cell = PyTuple_GET_SIZE(sg->GetGraph()->GetCodeObj()->co_cellvars) != 0;
|
||||
stat = has_cell ? InlineReason::kInlinePolicyDisabled : stat;
|
||||
}
|
||||
|
||||
auto code = subgraph->GetGraph()->GetGuard();
|
||||
auto code = sg->GetGraph()->GetGuard();
|
||||
MS_EXCEPTION_IF_NULL(code);
|
||||
code->GetGuard()->Backup();
|
||||
|
||||
|
@ -2071,18 +2072,18 @@ StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth,
|
|||
}
|
||||
|
||||
MS_LOG(INFO) << "new subgraph->TraceRun:" << py::str(func);
|
||||
auto reason = subgraph->TraceRun(args);
|
||||
sg->FGAddInputs(args);
|
||||
auto reason = sg->TraceRun();
|
||||
MS_LOG(INFO) << "new subgraph->TraceRun end:" << py::str(func);
|
||||
|
||||
call_node->SetSubGraph(subgraph->GetGraph());
|
||||
auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
|
||||
auto sub_ret = subgraph->GetGraph()->GetRetVal();
|
||||
call_node->SetSubGraph(sg->GetGraph());
|
||||
auto sub_ret = sg->GetGraph()->GetRetVal();
|
||||
if (sub_ret != nullptr) {
|
||||
if (sub_ret->GetVobj()->GetPyObject().ptr() == nullptr ||
|
||||
CheckConstPyObject(sub_ret->GetVobj()->GetPyObject().ptr())) {
|
||||
call_node->SetVobj(sub_ret->GetVobj());
|
||||
} else {
|
||||
sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, subgraph));
|
||||
sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, sg));
|
||||
sg->FGAddOutput();
|
||||
if (sg->FGBuilder()->graph() == nullptr) {
|
||||
MS_LOG(ERROR) << "subgraph trace null";
|
||||
|
@ -2140,7 +2141,7 @@ StopTraceReason GraphBuilder::BuildSubGraph(CallNode *call_node, int depth, cons
|
|||
code->GetGuard()->Backup();
|
||||
|
||||
MS_LOG(INFO) << "old subgraph->TraceRun";
|
||||
subgraph->TraceRun(call_node->GetArgs());
|
||||
subgraph->TraceRun();
|
||||
|
||||
call_node->SetSubGraph(subgraph->GetGraph());
|
||||
if (subgraph->GetGraph()->GetRetVal() != nullptr) {
|
||||
|
@ -2533,15 +2534,13 @@ bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *c
|
|||
|
||||
static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node);
|
||||
|
||||
StopTraceReason MindGraphBuilder::TraceRun(const std::vector<py::object> &args) {
|
||||
void MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
|
||||
// Add function graph inputs.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
MS_LOG(INFO) << "try add input: " << py::str(args[i]);
|
||||
FGBuilder()->AddInput(args[i]);
|
||||
MS_LOG(INFO) << "add input suc";
|
||||
}
|
||||
auto res = GraphBuilder::TraceRun(args);
|
||||
return res;
|
||||
}
|
||||
|
||||
void MindGraphBuilder::FGAddOutput() {
|
||||
|
@ -2553,7 +2552,6 @@ void MindGraphBuilder::FGAddOutput() {
|
|||
MS_LOG(INFO) << "add output succuss";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "add output fail";
|
||||
// TODO(xiaruijie)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2642,8 +2640,7 @@ py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReaso
|
|||
}
|
||||
return FGAddNode(call_node, callable_info, args, stop_reason);
|
||||
}
|
||||
if (FGBuilder()->CanConstantFoldFunc(callable_info) ||
|
||||
(CheckCell(callable_info) && callable->GetType() == AObject::kTypeType)) {
|
||||
if (FGBuilder()->CanConstantFoldFunc(callable_info)) {
|
||||
MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
|
||||
JustCallAndSetRes(call_node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
|
@ -3333,8 +3330,7 @@ static void EliminateCellAccess(Graph *g) {
|
|||
}
|
||||
}
|
||||
|
||||
StopTraceReason GraphBuilder::TraceRun(const std::vector<py::object> &args) {
|
||||
args_ = args;
|
||||
StopTraceReason GraphBuilder::TraceRun() {
|
||||
current_block_ = graph_->GetCFG()->GetFirstBB();
|
||||
cur_bci_ = 0;
|
||||
const auto &instrs = graph_->GetCFG()->instr_pool();
|
||||
|
@ -3387,7 +3383,7 @@ AObject *InferFuncResult(const py::object &callable, const py::object &args, con
|
|||
if (g == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
g->TraceRun(py::cast<py::list>(args).cast<std::vector<py::object>>());
|
||||
g->TraceRun();
|
||||
if (clear_guard) {
|
||||
Graph *graph = g->GetGraph();
|
||||
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(graph->GetCodeObj()));
|
||||
|
|
|
@ -66,7 +66,7 @@ class GraphBuilder {
|
|||
: std::make_shared<GraphBuilder>(r, p, co, globals);
|
||||
}
|
||||
|
||||
virtual StopTraceReason TraceRun(const std::vector<py::object> &args);
|
||||
StopTraceReason TraceRun();
|
||||
virtual bool trace_flag() { return false; }
|
||||
|
||||
void CollectInlineInfo(CallNode *node, int depth);
|
||||
|
@ -87,7 +87,6 @@ class GraphBuilder {
|
|||
TryBlock &PopStack();
|
||||
|
||||
protected:
|
||||
std::vector<py::object> args_; // inputs
|
||||
GraphBuilder *root_;
|
||||
GraphBuilder *parent_;
|
||||
Graph *graph_;
|
||||
|
@ -314,7 +313,7 @@ class MindGraphBuilder : public GraphBuilder {
|
|||
}
|
||||
bool trace_flag() { return true; }
|
||||
mindspore::FuncGraphBuilderPtr FGBuilder() const { return fg_builder_; }
|
||||
StopTraceReason TraceRun(const std::vector<py::object> &args);
|
||||
void FGAddInputs(const std::vector<py::object> &args);
|
||||
py::object FGAddNode(CallNode *call_node, const py::object &callable_info, const std::vector<py::object> &args,
|
||||
StopTraceReason *stop_reason);
|
||||
void FGAddOutput();
|
||||
|
|
|
@ -639,7 +639,7 @@ static bool CheckJitFunc(const py::object &o) {
|
|||
return size > except_size && !strncmp(file + (size - except_size), except_file, except_size);
|
||||
}
|
||||
|
||||
bool CheckCell(const py::object &callable_info) {
|
||||
static bool CheckCell(const py::object &callable_info) {
|
||||
PyTypeObject *cell_type = PyType_Check(callable_info.ptr()) ? reinterpret_cast<PyTypeObject *>(callable_info.ptr())
|
||||
: Py_TYPE(callable_info.ptr());
|
||||
if (!IsCellType<true>(cell_type)) {
|
||||
|
|
|
@ -42,7 +42,6 @@ const std::vector<std::pair<CheckFunc, std::string>> &GetFuncWhiteListFuzzyMatch
|
|||
const std::string GetMindsporeNamePrimitive();
|
||||
|
||||
bool InferListAppend(CallNode *call_node);
|
||||
|
||||
bool CheckCell(const py::object &callable_info);
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue