Adapter code_gen for constructing graph with trace
This commit is contained in:
parent
f785fabcf2
commit
0507f3b9a3
|
@ -577,7 +577,7 @@ static bool GraphCapture(JitCompileResults *jcr) {
|
|||
AObject::aobject_mem_pool_.Clear(__FILE__, __LINE__);
|
||||
|
||||
bool captured = !analyzer->NeedInterpret() && !conf.GetBoolConfig(GraphJitConfig::kInterpretCapturedCode);
|
||||
if (captured) {
|
||||
if (captured && !jcr->conf->GetBoolConfig(GraphJitConfig::kTraceFlag)) {
|
||||
jcr->stat = JitCompileResults::GRAPH_CAPTURED;
|
||||
}
|
||||
return new_code.ptr() != reinterpret_cast<PyObject *>(jcr->origin_frame_->f_code);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "pipeline/jit/pi/utils/utils.h"
|
||||
#include "pipeline/jit/pi/common.h"
|
||||
#include "pipeline/jit/pi/external.h"
|
||||
#include "pipeline/jit/pi/graph_compiler/compiler.h"
|
||||
|
||||
#ifndef _Py_MAKECODEUNIT
|
||||
#ifdef WORDS_BIGENDIAN
|
||||
|
@ -737,6 +738,12 @@ py::object CodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Inst
|
|||
return code;
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc,
|
||||
int code_flag) const {
|
||||
int flags = co_->co_flags & ~(CO_VARARGS | CO_VARKEYWORDS);
|
||||
return MakeCopyCode(AttachCodeID(MakeCompiledName(py::str(co_->co_name))), argc, 0, flags | code_flag);
|
||||
}
|
||||
|
||||
void CodeBreakGenerator::CallCapturedCode(CodeGenerator *code_gen) {
|
||||
if (captured_.operations.empty()) {
|
||||
return;
|
||||
|
@ -1042,14 +1049,56 @@ void CodeBreakGenerator::CallUntrackedCode(CodeGenerator *code_gen) {
|
|||
code_gen->NewInstr(RETURN_VALUE);
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount,
|
||||
int co_flags, bool make_graph) const {
|
||||
py::str py_co_name(co_name);
|
||||
PyCodeObject *new_code =
|
||||
PyCode_New(co_argcount, co_kwonlyargcount, co_->co_nlocals, co_->co_stacksize, co_flags, co_->co_code,
|
||||
co_->co_consts, co_->co_names, co_->co_varnames, co_->co_freevars, co_->co_cellvars, co_->co_filename,
|
||||
py_co_name.ptr(), co_->co_firstlineno, co_->co_lnotab);
|
||||
if (new_code == nullptr) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
auto copy_code = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(new_code));
|
||||
// Compile graph.
|
||||
auto b = std::dynamic_pointer_cast<MindGraphBuilder>(builder_);
|
||||
MS_EXCEPTION_IF_NULL(b);
|
||||
auto func_graph = FGBuilder()->graph();
|
||||
std::string phase =
|
||||
py::cast<std::string>(co_->co_filename) + "_" + std::to_string(co_->co_firstlineno) + "_" + co_name;
|
||||
const auto ¶meters = func_graph->parameters();
|
||||
py::tuple args(parameters.size());
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
phase += "_" + parameters[i]->abstract()->ToString();
|
||||
args[i] = *(parameters[i]->user_data<py::object>("pi_jit_py_obj"));
|
||||
}
|
||||
phase += ".pi_jit";
|
||||
MindCompiler::CompileInfo compile_info{co_name, co_argcount, co_kwonlyargcount, co_flags};
|
||||
CallableGraph callable = mindspore::pijit::MindCompiler::Compile(func_graph, args, py::dict(), phase, compile_info);
|
||||
// Set NativeFunc.
|
||||
auto parent = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
if (make_graph) {
|
||||
parent->code->SetNativeFunc(phase, callable, nullptr);
|
||||
} else {
|
||||
JitCompileResults *child = getJitCompileResults(copy_code.ptr());
|
||||
child->code = child->codehub->AddOptTarget(OptOption::CreateOptionByPoint(child));
|
||||
child->code->SetNativeFunc(phase, callable, nullptr);
|
||||
child->stat = CodeExtra::GRAPH_CALLABLE;
|
||||
child->conf = parent->conf;
|
||||
child->tbs = parent->tbs;
|
||||
}
|
||||
|
||||
return copy_code;
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
|
||||
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
|
||||
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
|
||||
if (make_graph) {
|
||||
// all parameters is graph supported
|
||||
captured_.inputs.clear();
|
||||
captured_.outputs.clear();
|
||||
interpret_.operations = std::move(captured_.operations);
|
||||
co_name = MakeCompiledName(co_name);
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
return MakeCopyCode(AttachCodeID(co_name), co_->co_argcount + co_->co_kwonlyargcount, 0, co_->co_flags, true);
|
||||
}
|
||||
|
||||
CodeGenerator code_gen(&interpret_);
|
||||
|
@ -1061,13 +1110,11 @@ py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
|
|||
code_gen.Build();
|
||||
|
||||
CallCapturedCode(&code_gen);
|
||||
FixInterpretOuput(&code_gen);
|
||||
// ... handle side effects
|
||||
CallUntrackedCode(&code_gen);
|
||||
MakeReturn(&code_gen);
|
||||
|
||||
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
|
||||
if (make_graph) {
|
||||
co_name = MakeCompiledName(co_name);
|
||||
}
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
|
||||
int nlocals = code_gen.GetLocalsMap().size();
|
||||
|
@ -1086,14 +1133,6 @@ py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
|
|||
|
||||
code_gen.EraseUnusedInstr();
|
||||
py::object result = CodeGenerator::Transform(code_gen.GetCode());
|
||||
if (make_graph) {
|
||||
JitCompileResults *child = getJitCompileResults(result.ptr());
|
||||
MS_LOG(INFO) << "child->fg = " << FGBuilder()->graph();
|
||||
// child->fg = FGBuilder()->graph();
|
||||
child->stat = CodeExtra::GRAPH_CAPTURED;
|
||||
child->conf = jcr->conf;
|
||||
child->tbs = jcr->tbs;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -201,7 +201,7 @@ class CodeBreakGenerator {
|
|||
void BuildGraphParameters(const std::unordered_map<ValueNode *, int> &locals, GraphParameterBuilder *);
|
||||
|
||||
// rebuild captured nodes to bytecode, build parameters load operations
|
||||
py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, int flag) const;
|
||||
virtual py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, int flag) const;
|
||||
|
||||
// make call operations of graph, build parameters load operations
|
||||
void CallCapturedCode(CodeGenerator *code_gen);
|
||||
|
@ -272,7 +272,13 @@ class MindCodeBreakGenerator : public CodeBreakGenerator {
|
|||
return std::dynamic_pointer_cast<MindGraphBuilder>(builder_)->FGBuilder();
|
||||
}
|
||||
|
||||
|
||||
py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc, int code_flag) const override;
|
||||
|
||||
private:
|
||||
py::object MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount, int co_flags,
|
||||
bool make_graph = false) const;
|
||||
|
||||
GraphBuilderPtr builder_;
|
||||
};
|
||||
// add a key and value to py::dict, check key conflict or rename the key
|
||||
|
|
|
@ -98,11 +98,11 @@ py::tuple EliminateSelf(const py::tuple &args, const std::string &name) {
|
|||
return args;
|
||||
}
|
||||
|
||||
py::tuple EliminateInvalidArgs(const py::tuple &args, const PyCodeObject &code, bool enable_tuple_broaden) {
|
||||
py::tuple EliminateInvalidArgs(const py::tuple &args, int co_flags, bool enable_tuple_broaden) {
|
||||
py::list new_args;
|
||||
for (size_t idx = 0; idx < args.size(); idx++) {
|
||||
if (IsValidRunArg(args[idx], enable_tuple_broaden)) {
|
||||
if ((idx < (args.size() - 1) || (code.co_flags & CO_VARKEYWORDS) == 0) && py::isinstance<py::dict>(args[idx])) {
|
||||
if ((idx < (args.size() - 1) || (co_flags & CO_VARKEYWORDS) == 0) && py::isinstance<py::dict>(args[idx])) {
|
||||
new_args.append(py::reinterpret_steal<py::tuple>(PyDict_Values(args[idx].ptr())));
|
||||
} else {
|
||||
new_args.append(args[idx]);
|
||||
|
@ -112,19 +112,19 @@ py::tuple EliminateInvalidArgs(const py::tuple &args, const PyCodeObject &code,
|
|||
return py::cast<py::tuple>(new_args);
|
||||
}
|
||||
|
||||
py::tuple ExpandVariableArgs(const py::tuple &args, const PyCodeObject &code) {
|
||||
if ((code.co_flags & CO_VARARGS) == 0x0) {
|
||||
py::tuple ExpandVariableArgs(const py::tuple &args, int co_flags, int co_argcount) {
|
||||
if ((co_flags & CO_VARARGS) == 0x0) {
|
||||
return args;
|
||||
}
|
||||
py::tuple var_args = py::cast<py::tuple>(args[code.co_argcount]);
|
||||
py::tuple var_args = py::cast<py::tuple>(args[co_argcount]);
|
||||
py::list new_args;
|
||||
for (int index = 0; index < code.co_argcount; index++) {
|
||||
for (int index = 0; index < co_argcount; index++) {
|
||||
new_args.append(args[index]);
|
||||
}
|
||||
for (const auto &var_arg : var_args) {
|
||||
new_args.append(var_arg);
|
||||
}
|
||||
for (size_t index = code.co_argcount + 1; index < args.size(); index++) {
|
||||
for (size_t index = co_argcount + 1; index < args.size(); index++) {
|
||||
new_args.append(args[index]);
|
||||
}
|
||||
return py::cast<py::tuple>(new_args);
|
||||
|
@ -144,12 +144,12 @@ CallableGraph Compiler::Compile(const PyFunctionObject &func, const PyFrameObjec
|
|||
"Excepted nullptr or a Dict Object for run kwargs.");
|
||||
|
||||
py::tuple tuple = MergeAllArgments(args, kwargs);
|
||||
tuple = ExpandVariableArgs(tuple, *code);
|
||||
tuple = ExpandVariableArgs(tuple, code->co_flags, code->co_argcount);
|
||||
std::string name = py::cast<std::string>(code->co_name);
|
||||
tuple = EliminateSelf(tuple, name);
|
||||
tuple = EliminateStubTensor(tuple);
|
||||
MarkArgmentMutable(tuple);
|
||||
tuple = EliminateInvalidArgs(tuple, *code, enable_tuple_broaden);
|
||||
tuple = EliminateInvalidArgs(tuple, code->co_flags, enable_tuple_broaden);
|
||||
auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(graph_executor);
|
||||
py::object ret = graph_executor->Run(tuple, py::str(phase));
|
||||
|
@ -188,12 +188,65 @@ CallableGraph Compiler::Compile(const PyFunctionObject &func, const PyFrameObjec
|
|||
if (graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
args = ExpandVariableArgs(args, *code);
|
||||
args = ExpandVariableArgs(args, code->co_flags, code->co_argcount);
|
||||
args = EliminateSelf(args, name);
|
||||
MarkArgmentMutable(args);
|
||||
(void)graph_executor->CompileInner(graph, args, kwargs, phase, true);
|
||||
|
||||
return callable;
|
||||
}
|
||||
|
||||
CallableGraph MindCompiler::Compile(const FuncGraphPtr &func_graph, const py::tuple &args, const py::dict &kwargs,
|
||||
const std::string &phase, const CompileInfo &compile_info) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(!phase.empty(),
|
||||
"Phase name should not be empty for function " + compile_info.co_name_ + ".");
|
||||
|
||||
CallableGraph callable = [compile_info, phase](
|
||||
PyObject *args, PyObject *kwargs) -> PyObject * {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(PyTuple_Check(args), "Excepted a Tuple Object for run args.");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(((kwargs == nullptr) || PyDict_Check(kwargs)),
|
||||
"Excepted nullptr or a Dict Object for run kwargs.");
|
||||
|
||||
py::tuple tuple = MergeAllArgments(args, kwargs);
|
||||
tuple = ExpandVariableArgs(tuple, compile_info.co_flags_, compile_info.co_argcount_);
|
||||
tuple = EliminateSelf(tuple, compile_info.co_name_);
|
||||
tuple = EliminateStubTensor(tuple);
|
||||
MarkArgmentMutable(tuple);
|
||||
tuple = EliminateInvalidArgs(tuple, compile_info.co_flags_, true);
|
||||
auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(graph_executor);
|
||||
py::object ret = graph_executor->Run(tuple, py::str(phase));
|
||||
int mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
|
||||
auto executor = pynative::PyNativeExecutor::GetInstance();
|
||||
if (mode == kPynativeMode && executor->grad_flag()) {
|
||||
executor->grad_executor()->jit()->set_graph_phase(phase);
|
||||
executor->GradJit(ret, tuple);
|
||||
}
|
||||
FuncGraphPtr ms_func_graph = graph_executor->GetFuncGraph(phase);
|
||||
MS_EXCEPTION_IF_NULL(ms_func_graph);
|
||||
if (ms_func_graph->modify_output()) {
|
||||
ret = py::cast<py::tuple>(ret)[0];
|
||||
}
|
||||
ret = python_adapter::CallPyFn("mindspore.common.api", "_convert_python_data", ret);
|
||||
ret.inc_ref();
|
||||
return ret.ptr();
|
||||
};
|
||||
|
||||
auto graph_executor = mindspore::pipeline::GraphExecutorPy::GetInstance();
|
||||
if (graph_executor->HasCompiled(phase)) {
|
||||
return callable;
|
||||
}
|
||||
|
||||
if (func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
py::tuple new_arg = EliminateStubTensor(args);
|
||||
new_arg = ExpandVariableArgs(new_arg, compile_info.co_flags_, compile_info.co_argcount_);
|
||||
new_arg = EliminateSelf(new_arg, compile_info.co_name_);
|
||||
MarkArgmentMutable(new_arg);
|
||||
(void)graph_executor->CompileInner(func_graph, args, kwargs, phase, true);
|
||||
|
||||
return callable;
|
||||
}
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <functional>
|
||||
#include <string>
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "pipeline/jit/pi/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pijit {
|
||||
|
@ -34,6 +35,21 @@ class Compiler {
|
|||
private:
|
||||
Compiler() = default;
|
||||
};
|
||||
|
||||
class MindCompiler {
|
||||
public:
|
||||
struct CompileInfo {
|
||||
std::string co_name_;
|
||||
int co_argcount_;
|
||||
int co_kwonlyargcount_;
|
||||
int co_flags_;
|
||||
};
|
||||
static CallableGraph Compile(const FuncGraphPtr &func_graph, const py::tuple &args, const py::dict &kwargs,
|
||||
const std::string &phase, const CompileInfo &compile_info);
|
||||
|
||||
private:
|
||||
MindCompiler() = default;
|
||||
};
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue