clean code

This commit is contained in:
r1chardf1d0 2024-03-06 11:37:21 +08:00
parent ccbaf161ff
commit affe33e8ab
6 changed files with 47 additions and 38 deletions

View File

@ -495,6 +495,18 @@ static void MarkBreak(Graph *g) {
}
}
std::vector<py::object> GetAllArgs(JitCompileResults *jcr) {
auto all_args = PackArgs(jcr->origin_frame_);
auto args = py::cast<py::list>(all_args[0]);
if (all_args[1].ptr() != nullptr) {
PyList_Append(args.ptr(), all_args[1].ptr()); // args + vargs
}
if (all_args[2].ptr() != nullptr) {
PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs
}
return args.cast<std::vector<py::object>>();
}
// preprocess before compile, split bytecode to sub-function
// return whether the code should be modified
static bool GraphCapture(JitCompileResults *jcr) {
@ -503,13 +515,12 @@ static bool GraphCapture(JitCompileResults *jcr) {
GraphJitConfig &conf = *jcr->conf;
auto g = GraphBuilder::Creator(jcr->origin_frame_, conf.GetBoolConfig(GraphJitConfig::kTraceFlag));
auto all_args = PackArgs(jcr->origin_frame_);
auto args = py::cast<py::list>(all_args[0]);
if (all_args[2].ptr() != nullptr) {
PyList_Append(args.ptr(), all_args[2].ptr()); // args + kwargs
}
(void)g->TraceRun(args.cast<std::vector<py::object>>());
if (conf.GetBoolConfig(GraphJitConfig::kTraceFlag)) {
auto mg = std::dynamic_pointer_cast<MindGraphBuilder>(g);
mg->FGAddInputs(GetAllArgs(jcr));
}
(void)g->TraceRun();
if (g->StackSize() > 0) {
auto block = g->PeekStack(0);
@ -939,9 +950,6 @@ std::vector<py::object> PackArgs(const PyFrameObject *frame) {
args[argi] = py::reinterpret_borrow<py::object>(PyCell_GET(cell));
}
}
if (vargs.ptr() != nullptr) {
PyList_Append(args.ptr(), vargs.ptr());
}
return {args, vargs, kwvargs};
}
@ -1151,6 +1159,9 @@ static py::object CallCompiledResults(PyThreadState *tstate, PyFrameObject *f, c
ValidateCompiledResults(c);
std::vector<py::object> packed_args = PackArgs(f);
if (packed_args[1].ptr() != nullptr) {
PyList_Append(packed_args[0].ptr(), packed_args[1].ptr());
}
py::object args = py::reinterpret_steal<py::object>(PyList_AsTuple(packed_args[0].ptr()));
py::object kwvargs = packed_args[2];
@ -1410,9 +1421,13 @@ PyObject *EvalFrame(PyThreadState *tstate, PyFrameObject *f, int exc) {
}
py::object res;
try {
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "0");
}
res = CodeHook(tstate, c, f);
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2");
if (c->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
common::SetEnv("MS_DEV_JIT_SYNTAX_LEVEL", "2");
}
} catch (py::error_already_set &e) {
MS_LOG(ERROR) << "execute failed with " << e.what() << " at "
<< std::string(py::str(reinterpret_cast<PyObject *>(f->f_code)));

View File

@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() {
// UD analyze: alive nodes analysis
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
if (!aliveLocals.empty()) {
bool stop_analyze = false;
while (!stop_analyze) {
stop_analyze = AnalyzeAliveLocals(aliveLocals);
if (stop_analyze) {
bool isStopAnalyze = false;
while (!isStopAnalyze) {
isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
if (isStopAnalyze) {
break;
}
aliveLocals = GetAliveLocals(graph_);

View File

@ -2034,7 +2034,7 @@ bool GraphBuilder::ReplaceCall(CallNode *call_node, const py::object &old_func)
}
namespace {
std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subgraph) {
std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) {
auto func_str = py::cast<std::string>(py::str(func));
std::vector<std::string> vec;
std::istringstream iss(func_str);
@ -2053,15 +2053,16 @@ std::string GetFuncGraphName(const py::object &func, const GraphBuilderPtr &subg
StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
const GraphBuilderPtr &subgraph) {
auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
InlineReason stat = InlineReason::kInline;
bool is_make_func = call_node->input(0)->GetOpcode() == MAKE_FUNCTION;
if (is_make_func) {
// inline MAKE_FUNCTION, need eliminate cell and free variable if the function is not dead local.
bool has_cell = PyTuple_GET_SIZE(subgraph->GetGraph()->GetCodeObj()->co_cellvars) != 0;
bool has_cell = PyTuple_GET_SIZE(sg->GetGraph()->GetCodeObj()->co_cellvars) != 0;
stat = has_cell ? InlineReason::kInlinePolicyDisabled : stat;
}
auto code = subgraph->GetGraph()->GetGuard();
auto code = sg->GetGraph()->GetGuard();
MS_EXCEPTION_IF_NULL(code);
code->GetGuard()->Backup();
@ -2071,18 +2072,18 @@ StopTraceReason MindGraphBuilder::BuildSubGraph(CallNode *call_node, int depth,
}
MS_LOG(INFO) << "new subgraph->TraceRun:" << py::str(func);
auto reason = subgraph->TraceRun(args);
sg->FGAddInputs(args);
auto reason = sg->TraceRun();
MS_LOG(INFO) << "new subgraph->TraceRun end:" << py::str(func);
call_node->SetSubGraph(subgraph->GetGraph());
auto sg = std::dynamic_pointer_cast<MindGraphBuilder>(subgraph);
auto sub_ret = subgraph->GetGraph()->GetRetVal();
call_node->SetSubGraph(sg->GetGraph());
auto sub_ret = sg->GetGraph()->GetRetVal();
if (sub_ret != nullptr) {
if (sub_ret->GetVobj()->GetPyObject().ptr() == nullptr ||
CheckConstPyObject(sub_ret->GetVobj()->GetPyObject().ptr())) {
call_node->SetVobj(sub_ret->GetVobj());
} else {
sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, subgraph));
sg->FGBuilder()->SetGraphName(GetFuncGraphName(func, sg));
sg->FGAddOutput();
if (sg->FGBuilder()->graph() == nullptr) {
MS_LOG(ERROR) << "subgraph trace null";
@ -2140,7 +2141,7 @@ StopTraceReason GraphBuilder::BuildSubGraph(CallNode *call_node, int depth, cons
code->GetGuard()->Backup();
MS_LOG(INFO) << "old subgraph->TraceRun";
subgraph->TraceRun(call_node->GetArgs());
subgraph->TraceRun();
call_node->SetSubGraph(subgraph->GetGraph());
if (subgraph->GetGraph()->GetRetVal() != nullptr) {
@ -2533,15 +2534,13 @@ bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *c
static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node);
StopTraceReason MindGraphBuilder::TraceRun(const std::vector<py::object> &args) {
void MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
// Add function graph inputs.
for (size_t i = 0; i < args.size(); ++i) {
MS_LOG(INFO) << "try add input: " << py::str(args[i]);
FGBuilder()->AddInput(args[i]);
MS_LOG(INFO) << "add input suc";
}
auto res = GraphBuilder::TraceRun(args);
return res;
}
void MindGraphBuilder::FGAddOutput() {
@ -2553,7 +2552,6 @@ void MindGraphBuilder::FGAddOutput() {
MS_LOG(INFO) << "add output succuss";
} else {
MS_LOG(ERROR) << "add output fail";
// TODO(xiaruijie)
}
}
}
@ -2642,8 +2640,7 @@ py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReaso
}
return FGAddNode(call_node, callable_info, args, stop_reason);
}
if (FGBuilder()->CanConstantFoldFunc(callable_info) ||
(CheckCell(callable_info) && callable->GetType() == AObject::kTypeType)) {
if (FGBuilder()->CanConstantFoldFunc(callable_info)) {
MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
JustCallAndSetRes(call_node);
*stop_reason = StopTraceReason::kNonStopTrace;
@ -3333,8 +3330,7 @@ static void EliminateCellAccess(Graph *g) {
}
}
StopTraceReason GraphBuilder::TraceRun(const std::vector<py::object> &args) {
args_ = args;
StopTraceReason GraphBuilder::TraceRun() {
current_block_ = graph_->GetCFG()->GetFirstBB();
cur_bci_ = 0;
const auto &instrs = graph_->GetCFG()->instr_pool();
@ -3387,7 +3383,7 @@ AObject *InferFuncResult(const py::object &callable, const py::object &args, con
if (g == nullptr) {
return nullptr;
}
g->TraceRun(py::cast<py::list>(args).cast<std::vector<py::object>>());
g->TraceRun();
if (clear_guard) {
Graph *graph = g->GetGraph();
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(graph->GetCodeObj()));

View File

@ -66,7 +66,7 @@ class GraphBuilder {
: std::make_shared<GraphBuilder>(r, p, co, globals);
}
virtual StopTraceReason TraceRun(const std::vector<py::object> &args);
StopTraceReason TraceRun();
virtual bool trace_flag() { return false; }
void CollectInlineInfo(CallNode *node, int depth);
@ -87,7 +87,6 @@ class GraphBuilder {
TryBlock &PopStack();
protected:
std::vector<py::object> args_; // inputs
GraphBuilder *root_;
GraphBuilder *parent_;
Graph *graph_;
@ -314,7 +313,7 @@ class MindGraphBuilder : public GraphBuilder {
}
bool trace_flag() { return true; }
mindspore::FuncGraphBuilderPtr FGBuilder() const { return fg_builder_; }
StopTraceReason TraceRun(const std::vector<py::object> &args);
void FGAddInputs(const std::vector<py::object> &args);
py::object FGAddNode(CallNode *call_node, const py::object &callable_info, const std::vector<py::object> &args,
StopTraceReason *stop_reason);
void FGAddOutput();

View File

@ -639,7 +639,7 @@ static bool CheckJitFunc(const py::object &o) {
return size > except_size && !strncmp(file + (size - except_size), except_file, except_size);
}
bool CheckCell(const py::object &callable_info) {
static bool CheckCell(const py::object &callable_info) {
PyTypeObject *cell_type = PyType_Check(callable_info.ptr()) ? reinterpret_cast<PyTypeObject *>(callable_info.ptr())
: Py_TYPE(callable_info.ptr());
if (!IsCellType<true>(cell_type)) {

View File

@ -42,7 +42,6 @@ const std::vector<std::pair<CheckFunc, std::string>> &GetFuncWhiteListFuzzyMatch
const std::string GetMindsporeNamePrimitive();
bool InferListAppend(CallNode *call_node);
bool CheckCell(const py::object &callable_info);
} // namespace pijit
} // namespace mindspore