Adapter code_gen for constructing graph with trace

This commit is contained in:
yujianfeng 2024-01-25 15:37:57 +08:00 committed by r1chardf1d0
parent f785fabcf2
commit 0507f3b9a3
5 changed files with 142 additions and 28 deletions

View File

@ -577,7 +577,7 @@ static bool GraphCapture(JitCompileResults *jcr) {
AObject::aobject_mem_pool_.Clear(__FILE__, __LINE__);
bool captured = !analyzer->NeedInterpret() && !conf.GetBoolConfig(GraphJitConfig::kInterpretCapturedCode);
if (captured) {
if (captured && !jcr->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
jcr->stat = JitCompileResults::GRAPH_CAPTURED;
}
return new_code.ptr() != reinterpret_cast<PyObject *>(jcr->origin_frame_->f_code);

View File

@ -23,6 +23,7 @@
#include "pipeline/jit/pi/utils/utils.h"
#include "pipeline/jit/pi/common.h"
#include "pipeline/jit/pi/external.h"
#include "pipeline/jit/pi/graph_compiler/compiler.h"
#ifndef _Py_MAKECODEUNIT
#ifdef WORDS_BIGENDIAN
@ -737,6 +738,12 @@ py::object CodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Inst
return code;
}
py::object MindCodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc,
int code_flag) const {
int flags = co_->co_flags & ~(CO_VARARGS | CO_VARKEYWORDS);
return MakeCopyCode(AttachCodeID(MakeCompiledName(py::str(co_->co_name))), argc, 0, flags | code_flag);
}
void CodeBreakGenerator::CallCapturedCode(CodeGenerator *code_gen) {
if (captured_.operations.empty()) {
return;
@ -1042,14 +1049,56 @@ void CodeBreakGenerator::CallUntrackedCode(CodeGenerator *code_gen) {
code_gen->NewInstr(RETURN_VALUE);
}
py::object MindCodeBreakGenerator::MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount,
int co_flags, bool make_graph) const {
py::str py_co_name(co_name);
PyCodeObject *new_code =
PyCode_New(co_argcount, co_kwonlyargcount, co_->co_nlocals, co_->co_stacksize, co_flags, co_->co_code,
co_->co_consts, co_->co_names, co_->co_varnames, co_->co_freevars, co_->co_cellvars, co_->co_filename,
py_co_name.ptr(), co_->co_firstlineno, co_->co_lnotab);
if (new_code == nullptr) {
throw py::error_already_set();
}
auto copy_code = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(new_code));
// Compile graph.
auto b = std::dynamic_pointer_cast<MindGraphBuilder>(builder_);
MS_EXCEPTION_IF_NULL(b);
auto func_graph = FGBuilder()->graph();
std::string phase =
py::cast<std::string>(co_->co_filename) + "_" + std::to_string(co_->co_firstlineno) + "_" + co_name;
const auto &parameters = func_graph->parameters();
py::tuple args(parameters.size());
for (size_t i = 0; i < parameters.size(); ++i) {
phase += "_" + parameters[i]->abstract()->ToString();
args[i] = *(parameters[i]->user_data<py::object>("pi_jit_py_obj"));
}
phase += ".pi_jit";
MindCompiler::CompileInfo compile_info{co_name, co_argcount, co_kwonlyargcount, co_flags};
CallableGraph callable = mindspore::pijit::MindCompiler::Compile(func_graph, args, py::dict(), phase, compile_info);
// Set NativeFunc.
auto parent = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
if (make_graph) {
parent->code->SetNativeFunc(phase, callable, nullptr);
} else {
JitCompileResults *child = getJitCompileResults(copy_code.ptr());
child->code = child->codehub->AddOptTarget(OptOption::CreateOptionByPoint(child));
child->code->SetNativeFunc(phase, callable, nullptr);
child->stat = CodeExtra::GRAPH_CALLABLE;
child->conf = parent->conf;
child->tbs = parent->tbs;
}
return copy_code;
}
py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
if (make_graph) {
// all parameters is graph supported
captured_.inputs.clear();
captured_.outputs.clear();
interpret_.operations = std::move(captured_.operations);
co_name = MakeCompiledName(co_name);
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
return MakeCopyCode(AttachCodeID(co_name), co_->co_argcount + co_->co_kwonlyargcount, 0, co_->co_flags, true);
}
CodeGenerator code_gen(&interpret_);
@ -1061,13 +1110,11 @@ py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
code_gen.Build();
CallCapturedCode(&code_gen);
FixInterpretOuput(&code_gen);
// ... handle side effects
CallUntrackedCode(&code_gen);
MakeReturn(&code_gen);
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
if (make_graph) {
co_name = MakeCompiledName(co_name);
}
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
int nlocals = code_gen.GetLocalsMap().size();
@ -1086,14 +1133,6 @@ py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
code_gen.EraseUnusedInstr();
py::object result = CodeGenerator::Transform(code_gen.GetCode());
if (make_graph) {
JitCompileResults *child = getJitCompileResults(result.ptr());
MS_LOG(INFO) << "child->fg = " << FGBuilder()->graph();
// child->fg = FGBuilder()->graph();
child->stat = CodeExtra::GRAPH_CAPTURED;
child->conf = jcr->conf;
child->tbs = jcr->tbs;
}
return result;
}

View File

@ -201,7 +201,7 @@ class CodeBreakGenerator {
void BuildGraphParameters(const std::unordered_map<ValueNode *, int> &locals, GraphParameterBuilder *);
// rebuild captured nodes to bytecode, build parameters load operations
py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, int flag) const;
virtual py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, int flag) const;
// make call operations of graph, build parameters load operations
void CallCapturedCode(CodeGenerator *code_gen);
@ -272,7 +272,13 @@ class MindCodeBreakGenerator : public CodeBreakGenerator {
return std::dynamic_pointer_cast<MindGraphBuilder>(builder_)->FGBuilder();
}
py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc, int code_flag) const override;
private:
py::object MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount, int co_flags,
bool make_graph = false) const;
GraphBuilderPtr builder_;
};
// add a key and value to py::dict, check key conflict or rename the key

View File

@ -98,11 +98,11 @@ py::tuple EliminateSelf(const py::tuple &args, const std::string &name) {
return args;
}
py::tuple EliminateInvalidArgs(const py::tuple &args, const PyCodeObject &code, bool enable_tuple_broaden) {
py::tuple EliminateInvalidArgs(const py::tuple &args, int co_flags, bool enable_tuple_broaden) {
py::list new_args;
for (size_t idx = 0; idx < args.size(); idx++) {
if (IsValidRunArg(args[idx], enable_tuple_broaden)) {
if ((idx < (args.size() - 1) || (code.co_flags & CO_VARKEYWORDS) == 0) && py::isinstance<py::dict>(args[idx])) {
if ((idx < (args.size() - 1) || (co_flags & CO_VARKEYWORDS) == 0) && py::isinstance<py::dict>(args[idx])) {
new_args.append(py::reinterpret_steal<py::tuple>(PyDict_Values(args[idx].ptr())));
} else {
new_args.append(args[idx]);
@ -112,19 +112,19 @@ py::tuple EliminateInvalidArgs(const py::tuple &args, const PyCodeObject &code,
return py::cast<py::tuple>(new_args);
}
py::tuple ExpandVariableArgs(const py::tuple &args, const PyCodeObject &code) {
if ((code.co_flags & CO_VARARGS) == 0x0) {
py::tuple ExpandVariableArgs(const py::tuple &args, int co_flags, int co_argcount) {
if ((co_flags & CO_VARARGS) == 0x0) {
return args;
}
py::tuple var_args = py::cast<py::tuple>(args[code.co_argcount]);
py::tuple var_args = py::cast<py::tuple>(args[co_argcount]);
py::list new_args;
for (int index = 0; index < code.co_argcount; index++) {
for (int index = 0; index < co_argcount; index++) {
new_args.append(args[index]);
}
for (const auto &var_arg : var_args) {
new_args.append(var_arg);
}
for (size_t index = code.co_argcount + 1; index < args.size(); index++) {
for (size_t index = co_argcount + 1; index < args.size(); index++) {
new_args.append(args[index]);
}
return py::cast<py::tuple>(new_args);
@ -144,12 +144,12 @@ CallableGraph Compiler::Compile(const PyFunctionObject &func, const PyFrameObjec
"Excepted nullptr or a Dict Object for run kwargs.");
py::tuple tuple = MergeAllArgments(args, kwargs);
tuple = ExpandVariableArgs(tuple, *code);
tuple = ExpandVariableArgs(tuple, code->co_flags, code->co_argcount);
std::string name = py::cast<std::string>(code->co_name);
tuple = EliminateSelf(tuple, name);
tuple = EliminateStubTensor(tuple);
MarkArgmentMutable(tuple);
tuple = EliminateInvalidArgs(tuple, *code, enable_tuple_broaden);
tuple = EliminateInvalidArgs(tuple, code->co_flags, enable_tuple_broaden);
auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
MS_EXCEPTION_IF_NULL(graph_executor);
py::object ret = graph_executor->Run(tuple, py::str(phase));
@ -188,12 +188,65 @@ CallableGraph Compiler::Compile(const PyFunctionObject &func, const PyFrameObjec
if (graph == nullptr) {
return nullptr;
}
args = ExpandVariableArgs(args, *code);
args = ExpandVariableArgs(args, code->co_flags, code->co_argcount);
args = EliminateSelf(args, name);
MarkArgmentMutable(args);
(void)graph_executor->CompileInner(graph, args, kwargs, phase, true);
return callable;
}
CallableGraph MindCompiler::Compile(const FuncGraphPtr &func_graph, const py::tuple &args, const py::dict &kwargs,
const std::string &phase, const CompileInfo &compile_info) {
MS_EXCEPTION_IF_CHECK_FAIL(!phase.empty(),
"Phase name should not be empty for function " + compile_info.co_name_ + ".");
CallableGraph callable = [compile_info, phase](
PyObject *args, PyObject *kwargs) -> PyObject * {
MS_EXCEPTION_IF_CHECK_FAIL(PyTuple_Check(args), "Excepted a Tuple Object for run args.");
MS_EXCEPTION_IF_CHECK_FAIL(((kwargs == nullptr) || PyDict_Check(kwargs)),
"Excepted nullptr or a Dict Object for run kwargs.");
py::tuple tuple = MergeAllArgments(args, kwargs);
tuple = ExpandVariableArgs(tuple, compile_info.co_flags_, compile_info.co_argcount_);
tuple = EliminateSelf(tuple, compile_info.co_name_);
tuple = EliminateStubTensor(tuple);
MarkArgmentMutable(tuple);
tuple = EliminateInvalidArgs(tuple, compile_info.co_flags_, true);
auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
MS_EXCEPTION_IF_NULL(graph_executor);
py::object ret = graph_executor->Run(tuple, py::str(phase));
int mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
auto executor = pynative::PyNativeExecutor::GetInstance();
if (mode == kPynativeMode && executor->grad_flag()) {
executor->grad_executor()->jit()->set_graph_phase(phase);
executor->GradJit(ret, tuple);
}
FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(phase);
MS_EXCEPTION_IF_NULL(ms_func_graph);
if (ms_func_graph->modify_output()) {
ret = py::cast<py::tuple>(ret)[0];
}
ret = python_adapter::CallPyFn("mindspore.common.api", "_convert_python_data", ret);
ret.inc_ref();
return ret.ptr();
};
auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
if (graph_executor->HasCompiled(phase)) {
return callable;
}
if (func_graph == nullptr) {
return nullptr;
}
py::tuple new_arg = EliminateStubTensor(args);
new_arg = ExpandVariableArgs(new_arg, compile_info.co_flags_, compile_info.co_argcount_);
new_arg = EliminateSelf(new_arg, compile_info.co_name_);
MarkArgmentMutable(new_arg);
(void)graph_executor->CompileInner(func_graph, args, kwargs, phase, true);
return callable;
}
} // namespace pijit
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include <functional>
#include <string>
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/pi/common.h"
namespace mindspore {
namespace pijit {
@ -34,6 +35,21 @@ class Compiler {
private:
Compiler() = default;
};
class MindCompiler {
public:
struct CompileInfo {
std::string co_name_;
int co_argcount_;
int co_kwonlyargcount_;
int co_flags_;
};
static CallableGraph Compile(const FuncGraphPtr &func_graph, const py::tuple &args, const py::dict &kwargs,
const std::string &phase, const CompileInfo &compile_info);
private:
MindCompiler() = default;
};
} // namespace pijit
} // namespace mindspore