clean code
This commit is contained in:
parent
affe33e8ab
commit
4556e93599
|
@ -738,12 +738,6 @@ py::object CodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Inst
|
|||
return code;
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc,
|
||||
int code_flag) const {
|
||||
int flags = co_->co_flags & ~(CO_VARARGS | CO_VARKEYWORDS);
|
||||
return MakeCopyCode(AttachCodeID(MakeCompiledName(py::str(co_->co_name))), argc, 0, flags | code_flag);
|
||||
}
|
||||
|
||||
void CodeBreakGenerator::CallCapturedCode(CodeGenerator *code_gen) {
|
||||
if (captured_.operations.empty()) {
|
||||
return;
|
||||
|
@ -1049,101 +1043,6 @@ void CodeBreakGenerator::CallUntrackedCode(CodeGenerator *code_gen) {
|
|||
code_gen->NewInstr(RETURN_VALUE);
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount,
|
||||
int co_flags, bool make_graph) const {
|
||||
py::str py_co_name(co_name);
|
||||
PyCodeObject *new_code =
|
||||
PyCode_New(co_argcount, co_kwonlyargcount, co_->co_nlocals, co_->co_stacksize, co_flags, co_->co_code,
|
||||
co_->co_consts, co_->co_names, co_->co_varnames, co_->co_freevars, co_->co_cellvars, co_->co_filename,
|
||||
py_co_name.ptr(), co_->co_firstlineno, co_->co_lnotab);
|
||||
if (new_code == nullptr) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
auto copy_code = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(new_code));
|
||||
// Compile graph.
|
||||
auto b = std::dynamic_pointer_cast<MindGraphBuilder>(builder_);
|
||||
MS_EXCEPTION_IF_NULL(b);
|
||||
auto func_graph = FGBuilder()->graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get function graph from function graph builder failed.";
|
||||
}
|
||||
std::string phase =
|
||||
py::cast<std::string>(co_->co_filename) + "_" + std::to_string(co_->co_firstlineno) + "_" + co_name;
|
||||
const auto ¶meters = func_graph->parameters();
|
||||
py::tuple args(parameters.size() - func_graph->fv_param_count());
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
auto para = parameters[i]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (para->has_default()) {
|
||||
continue;
|
||||
}
|
||||
phase += "_" + para->abstract()->ToString();
|
||||
args[i] = *(para->user_data<py::object>("pi_jit_py_obj"));
|
||||
}
|
||||
phase += ".pi_jit";
|
||||
MindCompiler::CompileInfo compile_info{co_name, co_argcount, co_kwonlyargcount, co_flags};
|
||||
CallableGraph callable = mindspore::pijit::MindCompiler::Compile(func_graph, args, py::dict(), phase, compile_info);
|
||||
// Set NativeFunc.
|
||||
auto parent = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
if (make_graph) {
|
||||
parent->code->SetNativeFunc(phase, callable, nullptr);
|
||||
} else {
|
||||
JitCompileResults *child = getJitCompileResults(copy_code.ptr());
|
||||
child->code = child->codehub->AddOptTarget(OptOption::CreateOptionByPoint(child));
|
||||
child->code->SetNativeFunc(phase, callable, nullptr);
|
||||
child->stat = CodeExtra::GRAPH_CALLABLE;
|
||||
child->conf = parent->conf;
|
||||
child->tbs = parent->tbs;
|
||||
}
|
||||
|
||||
return copy_code;
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
|
||||
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
|
||||
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
|
||||
if (make_graph) {
|
||||
co_name = MakeCompiledName(co_name);
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
return MakeCopyCode(AttachCodeID(co_name), co_->co_argcount + co_->co_kwonlyargcount, 0, co_->co_flags, true);
|
||||
}
|
||||
|
||||
CodeGenerator code_gen(&interpret_);
|
||||
code_gen.SetGlobals(GetGlobals());
|
||||
code_gen.Init();
|
||||
for (auto i : captured_.inputs) {
|
||||
code_gen.MarkAlive(i);
|
||||
}
|
||||
code_gen.Build();
|
||||
|
||||
CallCapturedCode(&code_gen);
|
||||
FixInterpretOuput(&code_gen);
|
||||
// ... handle side effects
|
||||
CallUntrackedCode(&code_gen);
|
||||
MakeReturn(&code_gen);
|
||||
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
|
||||
int nlocals = code_gen.GetLocalsMap().size();
|
||||
nlocals = std::max(nlocals, co_->co_nlocals);
|
||||
nlocals = std::max(nlocals, cfg_->GetLocalCount());
|
||||
|
||||
code_gen.SetArgsInfo(co_->co_argcount + co_->co_kwonlyargcount, 0);
|
||||
code_gen.SetLocalsCount(nlocals);
|
||||
code_gen.SetCodeFlags(co_->co_flags);
|
||||
code_gen.SetFirstLineNumber(co_->co_firstlineno);
|
||||
code_gen.SetVariableNames(py::cast<std::vector<std::string>>(co_->co_varnames));
|
||||
code_gen.SetCellVariableNames(py::cast<std::vector<std::string>>(co_->co_cellvars));
|
||||
code_gen.SetFreeVariableNames(py::cast<std::vector<std::string>>(co_->co_freevars));
|
||||
code_gen.SetCodeName(co_name);
|
||||
code_gen.SetFileName(py::reinterpret_borrow<py::object>(co_->co_filename));
|
||||
|
||||
code_gen.EraseUnusedInstr();
|
||||
py::object result = CodeGenerator::Transform(code_gen.GetCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
py::object CodeBreakGenerator::MakeCode(bool make_graph, Graph *graph) {
|
||||
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
|
||||
|
@ -1605,5 +1504,106 @@ std::string PrintNodeSet(const NodeSet &nodes) {
|
|||
return s.str();
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc,
|
||||
int code_flag) const {
|
||||
int flags = co_->co_flags & ~(CO_VARARGS | CO_VARKEYWORDS);
|
||||
return MakeCopyCode(AttachCodeID(MakeCompiledName(py::str(co_->co_name))), argc, 0, flags | code_flag);
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCopyCode(const std::string &co_name, int co_argcount, int co_kwonlyargcount,
|
||||
int co_flags, bool make_graph) const {
|
||||
py::str py_co_name(co_name);
|
||||
PyCodeObject *new_code =
|
||||
PyCode_New(co_argcount, co_kwonlyargcount, co_->co_nlocals, co_->co_stacksize, co_flags, co_->co_code,
|
||||
co_->co_consts, co_->co_names, co_->co_varnames, co_->co_freevars, co_->co_cellvars, co_->co_filename,
|
||||
py_co_name.ptr(), co_->co_firstlineno, co_->co_lnotab);
|
||||
if (new_code == nullptr) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
auto copy_code = py::reinterpret_steal<py::object>(reinterpret_cast<PyObject *>(new_code));
|
||||
// Compile graph.
|
||||
auto b = std::dynamic_pointer_cast<MindGraphBuilder>(builder_);
|
||||
MS_EXCEPTION_IF_NULL(b);
|
||||
auto func_graph = FGBuilder()->graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get function graph from function graph builder failed.";
|
||||
}
|
||||
std::string phase =
|
||||
py::cast<std::string>(co_->co_filename) + "_" + std::to_string(co_->co_firstlineno) + "_" + co_name;
|
||||
const auto ¶meters = func_graph->parameters();
|
||||
py::tuple args(parameters.size() - func_graph->fv_param_count());
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
auto para = parameters[i]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (para->has_default()) {
|
||||
continue;
|
||||
}
|
||||
phase += "_" + para->abstract()->ToString();
|
||||
args[i] = *(para->user_data<py::object>("pi_jit_py_obj"));
|
||||
}
|
||||
phase += ".pi_jit";
|
||||
MindCompiler::CompileInfo compile_info{co_name, co_argcount, co_kwonlyargcount, co_flags};
|
||||
CallableGraph callable = mindspore::pijit::MindCompiler::Compile(func_graph, args, py::dict(), phase, compile_info);
|
||||
// Set NativeFunc.
|
||||
auto parent = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
if (make_graph) {
|
||||
parent->code->SetNativeFunc(phase, callable, nullptr);
|
||||
} else {
|
||||
JitCompileResults *child = getJitCompileResults(copy_code.ptr());
|
||||
child->code = child->codehub->AddOptTarget(OptOption::CreateOptionByPoint(child));
|
||||
child->code->SetNativeFunc(phase, callable, nullptr);
|
||||
child->stat = CodeExtra::GRAPH_CALLABLE;
|
||||
child->conf = parent->conf;
|
||||
child->tbs = parent->tbs;
|
||||
}
|
||||
|
||||
return copy_code;
|
||||
}
|
||||
|
||||
py::object MindCodeBreakGenerator::MakeCode(bool make_graph) {
|
||||
auto jcr = getJitCompileResults(reinterpret_cast<PyObject *>(co_), false);
|
||||
|
||||
std::string co_name = PyUnicode_AsUTF8(co_->co_name);
|
||||
if (make_graph) {
|
||||
co_name = MakeCompiledName(co_name);
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
return MakeCopyCode(AttachCodeID(co_name), co_->co_argcount + co_->co_kwonlyargcount, 0, co_->co_flags, true);
|
||||
}
|
||||
|
||||
CodeGenerator code_gen(&interpret_);
|
||||
code_gen.SetGlobals(GetGlobals());
|
||||
code_gen.Init();
|
||||
for (auto i : captured_.inputs) {
|
||||
code_gen.MarkAlive(i);
|
||||
}
|
||||
code_gen.Build();
|
||||
|
||||
CallCapturedCode(&code_gen);
|
||||
FixInterpretOuput(&code_gen);
|
||||
// ... handle side effects
|
||||
CallUntrackedCode(&code_gen);
|
||||
MakeReturn(&code_gen);
|
||||
|
||||
co_name = std::to_string(jcr->IncCodeCount()) + "R." + co_name;
|
||||
|
||||
int nlocals = code_gen.GetLocalsMap().size();
|
||||
nlocals = std::max(nlocals, co_->co_nlocals);
|
||||
nlocals = std::max(nlocals, cfg_->GetLocalCount());
|
||||
|
||||
code_gen.SetArgsInfo(co_->co_argcount + co_->co_kwonlyargcount, 0);
|
||||
code_gen.SetLocalsCount(nlocals);
|
||||
code_gen.SetCodeFlags(co_->co_flags);
|
||||
code_gen.SetFirstLineNumber(co_->co_firstlineno);
|
||||
code_gen.SetVariableNames(py::cast<std::vector<std::string>>(co_->co_varnames));
|
||||
code_gen.SetCellVariableNames(py::cast<std::vector<std::string>>(co_->co_cellvars));
|
||||
code_gen.SetFreeVariableNames(py::cast<std::vector<std::string>>(co_->co_freevars));
|
||||
code_gen.SetCodeName(co_name);
|
||||
code_gen.SetFileName(py::reinterpret_borrow<py::object>(co_->co_filename));
|
||||
|
||||
code_gen.EraseUnusedInstr();
|
||||
py::object result = CodeGenerator::Transform(code_gen.GetCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -320,41 +320,6 @@ void GraphAnalyzer::UseDefAnalyze() {
|
|||
graph_->SetOldBreakBci(graph_->GetStopTraceBci());
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UpdateCapturedOrder() {
|
||||
const auto &traced_nodes = graph_->GetTracedNodes();
|
||||
auto stop_bci = graph_->GetStopTraceBci();
|
||||
if (stop_bci == -1) {
|
||||
GetCaptureInfo().captured_locals.order = traced_nodes;
|
||||
} else {
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
for (const auto &traced_node : traced_nodes) {
|
||||
if (traced_node->bci() >= stop_bci) {
|
||||
break;
|
||||
}
|
||||
GetCaptureInfo().captured_locals.order.push_back(traced_node);
|
||||
}
|
||||
}
|
||||
const auto &captured_local_order = GetCaptureInfo().captured_locals.order;
|
||||
std::set<ValueNode *> new_capture_local_values(captured_local_order.begin(), captured_local_order.end());
|
||||
GetCaptureInfo().captured_locals.values = new_capture_local_values;
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UseDefAnalyze() {
|
||||
// UD analyze: alive nodes analysis
|
||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||
if (!aliveLocals.empty()) {
|
||||
bool stop_analyze = false;
|
||||
while (!stop_analyze) {
|
||||
UpdateCapturedOrder();
|
||||
// Add graph output according to leaf nodes.
|
||||
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (!stop_analyze) {
|
||||
aliveLocals = GetAliveLocals(graph_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphAnalyzer::Analyze() {
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||
|
@ -402,111 +367,6 @@ void GraphAnalyzer::Analyze() {
|
|||
}
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::CollectInputs() {
|
||||
auto &inputs = GetCaptureInfo().captured_locals.inputs;
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
PyCodeObject *co = graph_->GetCodeObj();
|
||||
int argc = co->co_argcount + co->co_kwonlyargcount;
|
||||
argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
|
||||
argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
|
||||
for (Py_ssize_t m = 0; m < argc; ++m) {
|
||||
auto local = enter_frame.Local(m);
|
||||
if (local != &ValueNode::kUnboundLocal) {
|
||||
inputs.insert(enter_frame.Local(m));
|
||||
} else {
|
||||
const Py_ssize_t ncells = PyTuple_GET_SIZE(co->co_cellvars);
|
||||
for (Py_ssize_t i = 0; co->co_cell2arg && i < ncells; ++i) {
|
||||
Py_ssize_t argi = co->co_cell2arg[i];
|
||||
if (argi != CO_CELL_NOT_AN_ARG) {
|
||||
auto cell = enter_frame.Closure(i)->GetValue();
|
||||
inputs.insert(cell);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::Analyze() {
|
||||
auto origin_stop_bci = graph_->GetStopTraceBci();
|
||||
UseDefAnalyze();
|
||||
CollectInputs();
|
||||
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||
if (func_graph_builder->graph() == nullptr) {
|
||||
// Graph build failed, add all nodes to ordered_escaped_locals.
|
||||
MS_LOG(DEBUG) << "Failed to build graph";
|
||||
GetCaptureInfo().ordered_escaped_locals.clear();
|
||||
for (const auto &traced_node : graph_->GetTracedNodes()) {
|
||||
if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
|
||||
break;
|
||||
}
|
||||
AddToEscaped(traced_node);
|
||||
}
|
||||
graph_->StopTraceAt(origin_stop_bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
|
||||
need_interpret_ = true;
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
GetCaptureInfo().captured_locals.values.clear();
|
||||
GetCaptureInfo().captured_locals.inputs.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
need_interpret_ = true;
|
||||
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
|
||||
return;
|
||||
}
|
||||
bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
|
||||
if (!support_ret) {
|
||||
return;
|
||||
}
|
||||
need_interpret_ = false;
|
||||
}
|
||||
|
||||
bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
|
||||
bool isAllNodesSupportOutput = true;
|
||||
for (auto node : aliveNodes) {
|
||||
// If the value can get from local, no need to add to graph output.
|
||||
if (IsNonLocalValue(node)) {
|
||||
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
|
||||
continue;
|
||||
}
|
||||
auto capturedLocals = info_.captured_locals.order;
|
||||
if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
|
||||
continue;
|
||||
}
|
||||
AObject *o = node->GetVobj();
|
||||
auto out_py_obj = o->GetPyObject();
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||
if (func_graph_builder->AddOutput(out_py_obj, false)) {
|
||||
MS_LOG(DEBUG) << "Add output success.";
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Add output failed.";
|
||||
// reset break graph point
|
||||
isAllNodesSupportOutput = false;
|
||||
int new_break_point = node->bci();
|
||||
auto curNode = node;
|
||||
if (new_break_point == -1) {
|
||||
// No node is unsupported output since no node in captured output.
|
||||
isAllNodesSupportOutput = true;
|
||||
break;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
|
||||
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
|
||||
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
|
||||
}
|
||||
this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
|
||||
break;
|
||||
}
|
||||
return isAllNodesSupportOutput;
|
||||
}
|
||||
|
||||
FrameStates buildLastFrame(Graph *g) { return g->GetFrame(g->GetStopTraceBci()); }
|
||||
|
||||
std::vector<ValueNode *> GraphAnalyzer::GetAliveLocals(Graph *g) {
|
||||
|
@ -642,5 +502,145 @@ bool ValidateGraphParameters(ValueNode *node) {
|
|||
return unsupported_parameter.find(info->GetType()) == unsupported_parameter.end();
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::CollectInputs() {
|
||||
auto &inputs = GetCaptureInfo().captured_locals.inputs;
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
PyCodeObject *co = graph_->GetCodeObj();
|
||||
int argc = co->co_argcount + co->co_kwonlyargcount;
|
||||
argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
|
||||
argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
|
||||
for (Py_ssize_t m = 0; m < argc; ++m) {
|
||||
auto local = enter_frame.Local(m);
|
||||
if (local != &ValueNode::kUnboundLocal) {
|
||||
inputs.insert(enter_frame.Local(m));
|
||||
} else {
|
||||
const Py_ssize_t ncells = PyTuple_GET_SIZE(co->co_cellvars);
|
||||
for (Py_ssize_t i = 0; co->co_cell2arg && i < ncells; ++i) {
|
||||
Py_ssize_t argi = co->co_cell2arg[i];
|
||||
if (argi != CO_CELL_NOT_AN_ARG) {
|
||||
auto cell = enter_frame.Closure(i)->GetValue();
|
||||
inputs.insert(cell);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::Analyze() {
|
||||
auto origin_stop_bci = graph_->GetStopTraceBci();
|
||||
UseDefAnalyze();
|
||||
CollectInputs();
|
||||
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||
if (func_graph_builder->graph() == nullptr) {
|
||||
// Graph build failed, add all nodes to ordered_escaped_locals.
|
||||
MS_LOG(DEBUG) << "Failed to build graph";
|
||||
GetCaptureInfo().ordered_escaped_locals.clear();
|
||||
for (const auto &traced_node : graph_->GetTracedNodes()) {
|
||||
if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
|
||||
break;
|
||||
}
|
||||
AddToEscaped(traced_node);
|
||||
}
|
||||
graph_->StopTraceAt(origin_stop_bci, StopTraceReason::kStopTraceDataDependsOnGraphOut);
|
||||
need_interpret_ = true;
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
GetCaptureInfo().captured_locals.values.clear();
|
||||
GetCaptureInfo().captured_locals.inputs.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
need_interpret_ = true;
|
||||
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
|
||||
return;
|
||||
}
|
||||
bool support_ret = graph_->GetRetVal()->GetVobj() && graph_->GetRetVal()->GetVobj()->IsMindSporeSupportedType();
|
||||
if (!support_ret) {
|
||||
return;
|
||||
}
|
||||
need_interpret_ = false;
|
||||
}
|
||||
|
||||
bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) {
|
||||
bool isAllNodesSupportOutput = true;
|
||||
for (auto node : aliveNodes) {
|
||||
// If the value can get from local, no need to add to graph output.
|
||||
if (IsNonLocalValue(node)) {
|
||||
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
|
||||
continue;
|
||||
}
|
||||
auto capturedLocals = info_.captured_locals.order;
|
||||
if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
|
||||
continue;
|
||||
}
|
||||
AObject *o = node->GetVobj();
|
||||
auto out_py_obj = o->GetPyObject();
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||
if (func_graph_builder->AddOutput(out_py_obj, false)) {
|
||||
MS_LOG(DEBUG) << "Add output success.";
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Add output failed.";
|
||||
// reset break graph point
|
||||
isAllNodesSupportOutput = false;
|
||||
int new_break_point = node->bci();
|
||||
auto curNode = node;
|
||||
if (new_break_point == -1) {
|
||||
// No node is unsupported output since no node in captured output.
|
||||
isAllNodesSupportOutput = true;
|
||||
break;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
|
||||
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
|
||||
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
|
||||
}
|
||||
this->graph_->StopTraceAt(new_break_point, StopTraceReason::kStopTraceDataDependsOnGraphOut);
|
||||
break;
|
||||
}
|
||||
return isAllNodesSupportOutput;
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UpdateCapturedOrder() {
|
||||
const auto &traced_nodes = graph_->GetTracedNodes();
|
||||
auto stop_bci = graph_->GetStopTraceBci();
|
||||
if (stop_bci == -1) {
|
||||
GetCaptureInfo().captured_locals.order = traced_nodes;
|
||||
} else {
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
for (const auto &traced_node : traced_nodes) {
|
||||
if (traced_node->bci() >= stop_bci) {
|
||||
break;
|
||||
}
|
||||
GetCaptureInfo().captured_locals.order.push_back(traced_node);
|
||||
}
|
||||
}
|
||||
const auto &captured_local_order = GetCaptureInfo().captured_locals.order;
|
||||
std::set<ValueNode *> new_capture_local_values(captured_local_order.begin(), captured_local_order.end());
|
||||
GetCaptureInfo().captured_locals.values = new_capture_local_values;
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UseDefAnalyze() {
|
||||
// UD analyze: alive nodes analysis
|
||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||
if (!aliveLocals.empty()) {
|
||||
bool stop_analyze = false;
|
||||
while (!stop_analyze) {
|
||||
UpdateCapturedOrder();
|
||||
// Add graph output according to leaf nodes.
|
||||
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (!stop_analyze) {
|
||||
aliveLocals = GetAliveLocals(graph_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1554,26 +1554,6 @@ py::object GraphBuilder::GetFuncInfo(ValueNode *func_node) {
|
|||
return FindPyFunc(vobj);
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
|
||||
std::string special_func_key;
|
||||
if (IsFuncInWhiteList(callable, &special_func_key)) {
|
||||
call_node->SetSubGraph(NewGraph(nullptr, nullptr));
|
||||
call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
|
||||
bool has_sub_graph = HandleFuncInWhiteList(special_func_key, call_node);
|
||||
if (!has_sub_graph) {
|
||||
call_node->SetInlineReason(InlineReason::kInlineFuncSpecialize);
|
||||
MS_ASSERT(!call_node->GetSubGraph()); // check infer function
|
||||
return true;
|
||||
}
|
||||
call_node->SetInlineReason(InlineReason::kInline);
|
||||
ValueNode *ret_node = call_node->GetSubGraph()->GetRetVal();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(ret_node, "infer special function failed");
|
||||
seek(0) = ret_node;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
|
||||
const auto &conf = call_node->GetGraph()->Config();
|
||||
|
||||
|
@ -2534,288 +2514,6 @@ bool GraphBuilder::HandleCallParameters(const py::object &func_info, CallNode *c
|
|||
|
||||
static void SetGradFuncInfo(mindspore::pijit::CallNode *call_node);
|
||||
|
||||
void MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
|
||||
// Add function graph inputs.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
MS_LOG(INFO) << "try add input: " << py::str(args[i]);
|
||||
FGBuilder()->AddInput(args[i]);
|
||||
MS_LOG(INFO) << "add input suc";
|
||||
}
|
||||
}
|
||||
|
||||
void MindGraphBuilder::FGAddOutput() {
|
||||
if (auto ret = GetGraph()->GetRetVal()) {
|
||||
MS_LOG(INFO) << ret->GetVobj()->ToString();
|
||||
auto out = ret->GetVobj()->GetPyObject();
|
||||
MS_LOG(INFO) << "try add output: " << py::str(out) << " addr:" << out.ptr();
|
||||
if (FGBuilder()->AddOutput(out)) {
|
||||
MS_LOG(INFO) << "add output succuss";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "add output fail";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object MindGraphBuilder::FGAddNode(CallNode *call_node, const py::object &callable_info,
|
||||
const std::vector<py::object> &args, StopTraceReason *stop_reason) {
|
||||
MS_LOG(INFO) << "try add node: " << py::str(callable_info);
|
||||
TraceGuard trace_guard(GetLocation(call_node));
|
||||
auto res = FGBuilder()->AddNode(callable_info, args);
|
||||
if (res.ptr() == nullptr) {
|
||||
MS_LOG(ERROR) << "add node fail";
|
||||
*stop_reason = StopTraceReason::kTrace_Fail;
|
||||
} else {
|
||||
MS_LOG(INFO) << "add node suc";
|
||||
auto node = AbstractTraceNode::MakeAObject(res);
|
||||
MS_LOG(INFO) << py::str(node->GetPyObject());
|
||||
MS_LOG(INFO) << node->ToString();
|
||||
call_node->SetVobj(node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
std::vector<py::object> MindGraphBuilder::GetNewArgs(CallNode *call_node, AObject *vobj) {
|
||||
std::vector<py::object> new_args;
|
||||
vobj = vobj ? vobj : call_node->GetVobj();
|
||||
if (vobj->GetType() == AObject::kTypeCFunction) {
|
||||
MS_LOG(ERROR) << "not support cfunction";
|
||||
}
|
||||
auto new_callable_info = FindPyFunc(vobj);
|
||||
FrameStates f;
|
||||
ResolveClosure(new_callable_info, call_node->input(0), &f);
|
||||
if (!HandleCallParameters(new_callable_info, call_node, &f)) {
|
||||
MS_LOG(ERROR) << "HandleCallParameters error" << std::endl;
|
||||
}
|
||||
PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(new_callable_info.ptr()));
|
||||
int argc = co->co_argcount + co->co_kwonlyargcount;
|
||||
argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
|
||||
argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
|
||||
for (auto it = f.GetLocals().begin(); it != f.GetLocals().begin() + argc; it++) {
|
||||
std::set<AObject::Type> unsupported_parameter = {
|
||||
AObject::kTypeAnyValue, AObject::kTypeFunction, AObject::kTypeBoundMethod,
|
||||
AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
|
||||
};
|
||||
auto vobj = (*it)->GetVobj();
|
||||
if (vobj != nullptr) {
|
||||
auto pyobj = vobj->GetPyObject();
|
||||
if (pyobj.ptr() != nullptr) {
|
||||
if (unsupported_parameter.find(AbstractObjectBase::GetPyType(pyobj.ptr())) == unsupported_parameter.end()) {
|
||||
new_args.push_back(pyobj);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new_args;
|
||||
}
|
||||
|
||||
py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
|
||||
AObject *callable = call_node->input(0)->GetVobj();
|
||||
py::object callable_info;
|
||||
*stop_reason = StopTraceReason::kStopTraceInfer_Fail;
|
||||
call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
|
||||
if (!callable) {
|
||||
return callable_info;
|
||||
}
|
||||
callable_info = callable->GetPyObject();
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
return py::object();
|
||||
}
|
||||
MS_LOG(INFO) << "trace_flag for: " << py::str(callable_info);
|
||||
auto args = call_node->GetArgs();
|
||||
auto method = FGBuilder()->ConvertMethod(callable_info);
|
||||
if (method.ptr() != nullptr) {
|
||||
MS_LOG(INFO) << "convert method :" << py::str(callable_info) << " to " << py::str(method);
|
||||
callable_info = method;
|
||||
args = GetNewArgs(call_node, AObject::Convert(callable_info.ptr()));
|
||||
}
|
||||
auto func = FGBuilder()->ConvertFunction(callable_info);
|
||||
if (func.ptr() != nullptr) {
|
||||
MS_LOG(INFO) << "convert function:" << py::str(callable_info) << " to " << py::str(func);
|
||||
callable_info = func;
|
||||
}
|
||||
if (FGBuilder()->CheckCallable(callable_info)) {
|
||||
if (PyFunction_Check(callable_info.ptr())) {
|
||||
args = GetNewArgs(call_node);
|
||||
}
|
||||
return FGAddNode(call_node, callable_info, args, stop_reason);
|
||||
}
|
||||
if (FGBuilder()->CanConstantFoldFunc(callable_info)) {
|
||||
MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
|
||||
JustCallAndSetRes(call_node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
return py::object();
|
||||
}
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
callable_info = py::cast<py::object>(reinterpret_cast<PyObject *>(callable->GetTypeObject()));
|
||||
}
|
||||
|
||||
AObject::Type callable_type = callable->GetType();
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
if (callable->TestMsFlag(AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc)) {
|
||||
SetGradFuncInfo(call_node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
if (callable_type == AObject::kTypeType) {
|
||||
call_node->SetInlineReason(InlineReason::kInlineFunc_ArgType_IsClass);
|
||||
HandleCallClass(call_node);
|
||||
if (static_cast<AbstractType *>(callable)->GetTypeType() == AObject::kTypeCell) {
|
||||
*stop_reason = StopTraceReason::kStopTraceInfer_Fail;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
if (WhiteListFuncCheckAndInfer(call_node, callable_info)) {
|
||||
return py::object();
|
||||
}
|
||||
|
||||
// find code object
|
||||
auto vobj = AObject::Convert(callable_info.ptr());
|
||||
if (vobj->GetType() == AObject::kTypeCFunction) {
|
||||
callable_info = py::object();
|
||||
}
|
||||
callable_info = FindPyFunc(vobj);
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
*stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
|
||||
call_node->SetInlineReason(InlineReason::kInlineCFunction_Unsupported);
|
||||
}
|
||||
return callable_info;
|
||||
}
|
||||
|
||||
AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare) {
|
||||
int opcode = instr.op();
|
||||
int oparg = instr.arg();
|
||||
std::vector<py::object> input_obj;
|
||||
for (auto input : p) {
|
||||
if (input->GetVobj() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
(void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
|
||||
}
|
||||
const auto &op_name =
|
||||
is_compare ? pijit::GraphUtils::OpCompareArgToGraphName(oparg) : pijit::GraphUtils::OpCodeToGraphName(opcode);
|
||||
MS_LOG(DEBUG) << "operation name is " << op_name;
|
||||
if (op_name == "") {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
auto node = fg_builder_->AddMultiNode(op_name, input_obj);
|
||||
if (node.ptr() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
return AbstractTraceNode::MakeAObject(node);
|
||||
}
|
||||
|
||||
AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
|
||||
auto opcode = instr.op();
|
||||
std::vector<py::object> input_obj;
|
||||
for (auto input : p) {
|
||||
if (input->GetVobj() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
(void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
|
||||
}
|
||||
auto primitive = pijit::GraphUtils::GetPrimitive(opcode);
|
||||
if (primitive == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
if (primitive == prim::kPrimMakeDict) {
|
||||
if (opcode == BUILD_CONST_KEY_MAP) {
|
||||
MS_LOG(DEBUG) << "BUILD_CONST_KEY_MAP case, need to pack values.";
|
||||
std::vector<py::object> value_inputs;
|
||||
(void)std::transform(input_obj.begin(), input_obj.end() - 1, std::back_inserter(value_inputs),
|
||||
[](const py::object &obj) { return obj; });
|
||||
auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_inputs);
|
||||
input_obj = {input_obj.back(), value_node};
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "BUILD_KEY_MAP case, need to pack keys and values.";
|
||||
size_t input_len = input_obj.size();
|
||||
if (input_len % 2 != 0) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "BUILD_KEY_MAP should have even input, but got: " << input_len;
|
||||
}
|
||||
std::vector<py::object> key_obj;
|
||||
std::vector<py::object> value_obj;
|
||||
for (size_t i = 0; i < input_len / 2; ++i) {
|
||||
key_obj.push_back(input_obj[2 * i]);
|
||||
value_obj.push_back(input_obj[2 * i + 1]);
|
||||
}
|
||||
auto key_node = fg_builder_->AddNode(prim::kPrimMakeTuple, key_obj);
|
||||
auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_obj);
|
||||
input_obj = {key_node, value_node};
|
||||
}
|
||||
}
|
||||
if (primitive == prim::kPrimMakeSlice) {
|
||||
constexpr size_t slice_without_step_len = 2;
|
||||
if (input_obj.size() == slice_without_step_len) {
|
||||
// Handle slice without step input scene, such as 0:2. MakeSlice can only handle slice with full inputs.
|
||||
(void)input_obj.emplace_back(py::int_(1));
|
||||
}
|
||||
}
|
||||
auto node = fg_builder_->AddNode(primitive, input_obj);
|
||||
return AbstractTraceNode::MakeAObject(node);
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoGetItem(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoUnary(const Instr &instr) {
|
||||
auto o = pop();
|
||||
auto r = HandleMultiOp(instr, {o}, false);
|
||||
auto v = NewValueNode(r, instr, {o});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBinary(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBinaryMul(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoCompare(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, true);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBuildOp(const Instr &instr) {
|
||||
int opcode = instr.op();
|
||||
int oparg = instr.arg();
|
||||
int tmp_arg = oparg;
|
||||
tmp_arg += opcode == BUILD_CONST_KEY_MAP;
|
||||
tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
|
||||
std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
|
||||
auto o = HandleBuildOp(instr, p);
|
||||
popn(tmp_arg);
|
||||
auto v = NewValueNode(o, instr, p);
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
py::object GraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
|
||||
AObject *callable = call_node->input(0)->GetVobj();
|
||||
py::object callable_info;
|
||||
|
@ -3545,5 +3243,325 @@ LocationPtr MindGraphBuilder::GetLocation(CallNode *call_node) const {
|
|||
std::vector<std::string> comments;
|
||||
return std::make_shared<Location>(file_name, line_no, 0, line_no, 0, "", std::move(comments));
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::WhiteListFuncCheckAndInfer(CallNode *call_node, const py::object &callable) {
|
||||
std::string special_func_key;
|
||||
if (IsFuncInWhiteList(callable, &special_func_key)) {
|
||||
call_node->SetSubGraph(NewGraph(nullptr, nullptr));
|
||||
call_node->GetSubGraph()->SetGuard(root_->GetGraph()->GetGuard());
|
||||
bool has_sub_graph = HandleFuncInWhiteList(special_func_key, call_node);
|
||||
if (!has_sub_graph) {
|
||||
call_node->SetInlineReason(InlineReason::kInlineFuncSpecialize);
|
||||
MS_ASSERT(!call_node->GetSubGraph()); // check infer function
|
||||
return true;
|
||||
}
|
||||
call_node->SetInlineReason(InlineReason::kInline);
|
||||
ValueNode *ret_node = call_node->GetSubGraph()->GetRetVal();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(ret_node, "infer special function failed");
|
||||
seek(0) = ret_node;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::string GetFuncGraphName(const py::object &func, const MindGraphBuilderPtr &subgraph) {
|
||||
auto func_str = py::cast<std::string>(py::str(func));
|
||||
std::vector<std::string> vec;
|
||||
std::istringstream iss(func_str);
|
||||
std::string str;
|
||||
while (iss >> str) {
|
||||
(void)vec.emplace_back(str);
|
||||
}
|
||||
if (vec.size() <= 1) {
|
||||
return "";
|
||||
}
|
||||
auto func_name = vec[1];
|
||||
std::replace(func_name.begin(), func_name.end(), '.', '_');
|
||||
return func_name + "_" + std::to_string(subgraph->GetGraph()->GetCodeObj()->co_firstlineno);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void MindGraphBuilder::FGAddInputs(const std::vector<py::object> &args) {
|
||||
// Add function graph inputs.
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
MS_LOG(INFO) << "try add input: " << py::str(args[i]);
|
||||
FGBuilder()->AddInput(args[i]);
|
||||
MS_LOG(INFO) << "add input suc";
|
||||
}
|
||||
}
|
||||
|
||||
void MindGraphBuilder::FGAddOutput() {
|
||||
if (auto ret = GetGraph()->GetRetVal()) {
|
||||
MS_LOG(INFO) << ret->GetVobj()->ToString();
|
||||
auto out = ret->GetVobj()->GetPyObject();
|
||||
MS_LOG(INFO) << "try add output: " << py::str(out) << " addr:" << out.ptr();
|
||||
if (FGBuilder()->AddOutput(out)) {
|
||||
MS_LOG(INFO) << "add output succuss";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "add output fail";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object MindGraphBuilder::FGAddNode(CallNode *call_node, const py::object &callable_info,
|
||||
const std::vector<py::object> &args, StopTraceReason *stop_reason) {
|
||||
MS_LOG(INFO) << "try add node: " << py::str(callable_info);
|
||||
TraceGuard trace_guard(GetLocation(call_node));
|
||||
auto res = FGBuilder()->AddNode(callable_info, args);
|
||||
if (res.ptr() == nullptr) {
|
||||
MS_LOG(ERROR) << "add node fail";
|
||||
*stop_reason = StopTraceReason::kTrace_Fail;
|
||||
} else {
|
||||
MS_LOG(INFO) << "add node suc";
|
||||
auto node = AbstractTraceNode::MakeAObject(res);
|
||||
MS_LOG(INFO) << py::str(node->GetPyObject());
|
||||
MS_LOG(INFO) << node->ToString();
|
||||
call_node->SetVobj(node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
std::vector<py::object> MindGraphBuilder::GetNewArgs(CallNode *call_node, AObject *vobj) {
|
||||
std::vector<py::object> new_args;
|
||||
vobj = vobj ? vobj : call_node->GetVobj();
|
||||
if (vobj->GetType() == AObject::kTypeCFunction) {
|
||||
MS_LOG(ERROR) << "not support cfunction";
|
||||
}
|
||||
auto new_callable_info = FindPyFunc(vobj);
|
||||
FrameStates f;
|
||||
ResolveClosure(new_callable_info, call_node->input(0), &f);
|
||||
if (!HandleCallParameters(new_callable_info, call_node, &f)) {
|
||||
MS_LOG(ERROR) << "HandleCallParameters error" << std::endl;
|
||||
}
|
||||
PyCodeObject *co = reinterpret_cast<PyCodeObject *>(PyFunction_GET_CODE(new_callable_info.ptr()));
|
||||
int argc = co->co_argcount + co->co_kwonlyargcount;
|
||||
argc += (co->co_flags & CO_VARARGS) ? 1 : 0;
|
||||
argc += (co->co_flags & CO_VARKEYWORDS) ? 1 : 0;
|
||||
for (auto it = f.GetLocals().begin(); it != f.GetLocals().begin() + argc; it++) {
|
||||
std::set<AObject::Type> unsupported_parameter = {
|
||||
AObject::kTypeAnyValue, AObject::kTypeFunction, AObject::kTypeBoundMethod,
|
||||
AObject::kTypePrimitive, AObject::kTypeMetaFuncGraph, AObject::kTypeCell,
|
||||
};
|
||||
auto vobj = (*it)->GetVobj();
|
||||
if (vobj != nullptr) {
|
||||
auto pyobj = vobj->GetPyObject();
|
||||
if (pyobj.ptr() != nullptr) {
|
||||
if (unsupported_parameter.find(AbstractObjectBase::GetPyType(pyobj.ptr())) == unsupported_parameter.end()) {
|
||||
new_args.push_back(pyobj);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new_args;
|
||||
}
|
||||
|
||||
py::object MindGraphBuilder::ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) {
|
||||
AObject *callable = call_node->input(0)->GetVobj();
|
||||
py::object callable_info;
|
||||
*stop_reason = StopTraceReason::kStopTraceInfer_Fail;
|
||||
call_node->SetInlineReason(InlineReason::kInlineInfer_Fail);
|
||||
if (!callable) {
|
||||
return callable_info;
|
||||
}
|
||||
callable_info = callable->GetPyObject();
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
return py::object();
|
||||
}
|
||||
MS_LOG(INFO) << "trace_flag for: " << py::str(callable_info);
|
||||
auto args = call_node->GetArgs();
|
||||
auto method = FGBuilder()->ConvertMethod(callable_info);
|
||||
if (method.ptr() != nullptr) {
|
||||
MS_LOG(INFO) << "convert method :" << py::str(callable_info) << " to " << py::str(method);
|
||||
callable_info = method;
|
||||
args = GetNewArgs(call_node, AObject::Convert(callable_info.ptr()));
|
||||
}
|
||||
auto func = FGBuilder()->ConvertFunction(callable_info);
|
||||
if (func.ptr() != nullptr) {
|
||||
MS_LOG(INFO) << "convert function:" << py::str(callable_info) << " to " << py::str(func);
|
||||
callable_info = func;
|
||||
}
|
||||
if (FGBuilder()->CheckCallable(callable_info)) {
|
||||
if (PyFunction_Check(callable_info.ptr())) {
|
||||
args = GetNewArgs(call_node);
|
||||
}
|
||||
return FGAddNode(call_node, callable_info, args, stop_reason);
|
||||
}
|
||||
if (FGBuilder()->CanConstantFoldFunc(callable_info)) {
|
||||
MS_LOG(INFO) << "CanConstantFoldFunc for: " << py::str(callable_info);
|
||||
JustCallAndSetRes(call_node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
return py::object();
|
||||
}
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
callable_info = py::cast<py::object>(reinterpret_cast<PyObject *>(callable->GetTypeObject()));
|
||||
}
|
||||
|
||||
AObject::Type callable_type = callable->GetType();
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
if (callable->TestMsFlag(AObject::kMsFlagGradFunc | AObject::kMsFlagShardFunc | AObject::kMsFlagVmapFunc)) {
|
||||
SetGradFuncInfo(call_node);
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
*stop_reason = StopTraceReason::kNonStopTrace;
|
||||
if (callable_type == AObject::kTypeType) {
|
||||
call_node->SetInlineReason(InlineReason::kInlineFunc_ArgType_IsClass);
|
||||
HandleCallClass(call_node);
|
||||
if (static_cast<AbstractType *>(callable)->GetTypeType() == AObject::kTypeCell) {
|
||||
*stop_reason = StopTraceReason::kStopTraceInfer_Fail;
|
||||
}
|
||||
return py::object();
|
||||
}
|
||||
|
||||
if (WhiteListFuncCheckAndInfer(call_node, callable_info)) {
|
||||
return py::object();
|
||||
}
|
||||
|
||||
// find code object
|
||||
auto vobj = AObject::Convert(callable_info.ptr());
|
||||
if (vobj->GetType() == AObject::kTypeCFunction) {
|
||||
callable_info = py::object();
|
||||
}
|
||||
callable_info = FindPyFunc(vobj);
|
||||
if (callable_info.ptr() == nullptr) {
|
||||
*stop_reason = StopTraceReason::kStopTraceFunc_Type_Unsupported;
|
||||
call_node->SetInlineReason(InlineReason::kInlineCFunction_Unsupported);
|
||||
}
|
||||
return callable_info;
|
||||
}
|
||||
|
||||
AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare) {
|
||||
int opcode = instr.op();
|
||||
int oparg = instr.arg();
|
||||
std::vector<py::object> input_obj;
|
||||
for (auto input : p) {
|
||||
if (input->GetVobj() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
(void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
|
||||
}
|
||||
const auto &op_name =
|
||||
is_compare ? pijit::GraphUtils::OpCompareArgToGraphName(oparg) : pijit::GraphUtils::OpCodeToGraphName(opcode);
|
||||
MS_LOG(DEBUG) << "operation name is " << op_name;
|
||||
if (op_name == "") {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
auto node = fg_builder_->AddMultiNode(op_name, input_obj);
|
||||
if (node.ptr() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
return AbstractTraceNode::MakeAObject(node);
|
||||
}
|
||||
|
||||
AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
|
||||
auto opcode = instr.op();
|
||||
std::vector<py::object> input_obj;
|
||||
for (auto input : p) {
|
||||
if (input->GetVobj() == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
(void)input_obj.emplace_back(input->GetVobj()->GetPyObject());
|
||||
}
|
||||
auto primitive = pijit::GraphUtils::GetPrimitive(opcode);
|
||||
if (primitive == nullptr) {
|
||||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
if (primitive == prim::kPrimMakeDict) {
|
||||
if (opcode == BUILD_CONST_KEY_MAP) {
|
||||
MS_LOG(DEBUG) << "BUILD_CONST_KEY_MAP case, need to pack values.";
|
||||
std::vector<py::object> value_inputs;
|
||||
(void)std::transform(input_obj.begin(), input_obj.end() - 1, std::back_inserter(value_inputs),
|
||||
[](const py::object &obj) { return obj; });
|
||||
auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_inputs);
|
||||
input_obj = {input_obj.back(), value_node};
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "BUILD_KEY_MAP case, need to pack keys and values.";
|
||||
size_t input_len = input_obj.size();
|
||||
if (input_len % 2 != 0) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "BUILD_KEY_MAP should have even input, but got: " << input_len;
|
||||
}
|
||||
std::vector<py::object> key_obj;
|
||||
std::vector<py::object> value_obj;
|
||||
for (size_t i = 0; i < input_len / 2; ++i) {
|
||||
key_obj.push_back(input_obj[2 * i]);
|
||||
value_obj.push_back(input_obj[2 * i + 1]);
|
||||
}
|
||||
auto key_node = fg_builder_->AddNode(prim::kPrimMakeTuple, key_obj);
|
||||
auto value_node = fg_builder_->AddNode(prim::kPrimMakeTuple, value_obj);
|
||||
input_obj = {key_node, value_node};
|
||||
}
|
||||
}
|
||||
if (primitive == prim::kPrimMakeSlice) {
|
||||
constexpr size_t slice_without_step_len = 2;
|
||||
if (input_obj.size() == slice_without_step_len) {
|
||||
// Handle slice without step input scene, such as 0:2. MakeSlice can only handle slice with full inputs.
|
||||
(void)input_obj.emplace_back(py::int_(1));
|
||||
}
|
||||
}
|
||||
auto node = fg_builder_->AddNode(primitive, input_obj);
|
||||
return AbstractTraceNode::MakeAObject(node);
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoGetItem(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoUnary(const Instr &instr) {
|
||||
auto o = pop();
|
||||
auto r = HandleMultiOp(instr, {o}, false);
|
||||
auto v = NewValueNode(r, instr, {o});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBinary(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBinaryMul(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, false);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoCompare(const Instr &instr) {
|
||||
auto r = pop();
|
||||
auto l = pop();
|
||||
auto o = HandleMultiOp(instr, {l, r}, true);
|
||||
auto v = NewValueNode(o, instr, {l, r});
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoBuildOp(const Instr &instr) {
|
||||
int opcode = instr.op();
|
||||
int oparg = instr.arg();
|
||||
int tmp_arg = oparg;
|
||||
tmp_arg += opcode == BUILD_CONST_KEY_MAP;
|
||||
tmp_arg += opcode == BUILD_MAP ? tmp_arg : 0;
|
||||
std::vector<ValueNode *> p(frame_.GetStacks().end() - tmp_arg, frame_.GetStacks().end());
|
||||
auto o = HandleBuildOp(instr, p);
|
||||
popn(tmp_arg);
|
||||
auto v = NewValueNode(o, instr, p);
|
||||
push(v);
|
||||
return true;
|
||||
}
|
||||
} // namespace pijit
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue