fix ud for one stage

This commit is contained in:
liangzhibo 2024-02-21 15:35:20 +08:00 committed by r1chardf1d0
parent bc9f94dd1b
commit 8fce13d6bc
4 changed files with 112 additions and 39 deletions

View File

@ -330,10 +330,24 @@ py::object FuncGraphBuilder::TryToAddNode(const ValuePtr &callable_value, const
auto new_node = graph_->NewCNode(input_node_list);
// Return the converted python object.
py::object output_py_obj = ConvertToPyObj(abs);
if (output_py_obj.ptr() == nullptr) {
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
return py::object();
py::object output_py_obj;
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
auto fg = abs_func->func_graph();
if (fg == nullptr) {
return py::object();
}
auto obj = fg->python_obj();
if (obj == nullptr || !obj->isa<parse::PyObjectWrapper>()) {
return py::object();
}
output_py_obj = obj->cast_ptr<parse::PyObjectWrapper>()->obj();
} else {
output_py_obj = ConvertToPyObj(abs);
if (output_py_obj.ptr() == nullptr) {
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
return py::object();
}
}
new_node->set_abstract(abs);

View File

@ -1570,31 +1570,6 @@ py::object MakeCodeFromCodeGen(const GraphBuilderPtr &builder, const GraphAnalyz
GraphAnalyzer::CapturedInfo info = analyzer->GetCaptureInfo();
auto cg = CodeBreakGenerator::Creator(builder, graph->GetCodeObj());
if (builder->trace_flag()) {
if (analyzer->NeedInterpret()) {
std::vector<ValueNode *> new_capture_locals;
const auto &traced_nodes = graph->GetTracedNodes();
auto break_bci = graph->GetStopTraceBci();
for (const auto &node : traced_nodes) {
if (node->bci() >= break_bci) {
break;
}
new_capture_locals.push_back(node);
}
info.captured_locals.order = new_capture_locals;
// Collect inputs for break capture graph.
auto &values = info.captured_locals.values;
auto &inputs = info.captured_locals.inputs;
for (ValueNode *i : info.captured_locals.order) {
for (auto input : i->getInputs()) {
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
continue;
}
inputs.insert(input);
}
}
} else {
info.captured_locals.order = graph->GetTracedNodes();
}
auto mind_builder = std::dynamic_pointer_cast<MindGraphBuilder>(builder);
auto mind_fg_builder = mind_builder->FGBuilder();
MS_EXCEPTION_IF_NULL(mind_fg_builder);

View File

@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() {
// UD analyze: alive nodes analysis
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
if (!aliveLocals.empty()) {
bool isStopAnalyze = false;
while (!isStopAnalyze) {
isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
if (isStopAnalyze) {
bool stop_analyze = false;
while (!stop_analyze) {
stop_analyze = AnalyzeAliveLocals(aliveLocals);
if (stop_analyze) {
break;
}
aliveLocals = GetAliveLocals(graph_);
@ -320,6 +320,41 @@ void GraphAnalyzer::UseDefAnalyze() {
graph_->SetOldBreakBci(graph_->GetStopTraceBci());
}
void MindGraphAnalyzer::UpdateCapturedOrder() {
const auto &traced_nodes = graph_->GetTracedNodes();
auto stop_bci = graph_->GetStopTraceBci();
if (stop_bci == -1) {
GetCaptureInfo().captured_locals.order = traced_nodes;
} else {
GetCaptureInfo().captured_locals.order.clear();
for (const auto &traced_node : traced_nodes) {
if (traced_node->bci() >= stop_bci) {
break;
}
GetCaptureInfo().captured_locals.order.push_back(traced_node);
}
}
const auto &captured_local_order = GetCaptureInfo().captured_locals.order;
std::set<ValueNode *> new_capture_local_values(captured_local_order.begin(), captured_local_order.end());
GetCaptureInfo().captured_locals.values = new_capture_local_values;
}
void MindGraphAnalyzer::UseDefAnalyze() {
// UD analyze: alive nodes analysis
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
if (!aliveLocals.empty()) {
bool stop_analyze = false;
while (!stop_analyze) {
UpdateCapturedOrder();
// Add graph output according to leaf nodes.
stop_analyze = AnalyzeAliveLocals(aliveLocals);
if (!stop_analyze) {
aliveLocals = GetAliveLocals(graph_);
}
}
}
}
void GraphAnalyzer::Analyze() {
const FrameStates &enter_frame = graph_->GetFrame(0);
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
@ -367,10 +402,47 @@ void GraphAnalyzer::Analyze() {
}
}
void MindGraphAnalyzer::CollectInputs() {
auto &values = GetCaptureInfo().captured_locals.values;
auto &inputs = GetCaptureInfo().captured_locals.inputs;
for (ValueNode *i : GetCaptureInfo().captured_locals.order) {
for (auto input : i->getInputs()) {
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
continue;
}
inputs.insert(input);
}
}
}
void MindGraphAnalyzer::Analyze() {
auto origin_stop_bci = graph_->GetStopTraceBci();
UseDefAnalyze();
CollectInputs();
const FrameStates &enter_frame = graph_->GetFrame(0);
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
MS_EXCEPTION_IF_NULL(mind_graph_builder);
auto func_graph_builder = mind_graph_builder->FGBuilder();
if (func_graph_builder->graph() == nullptr) {
// Graph build failed, add all nodes to ordered_escaped_locals.
MS_LOG(DEBUG) << "Failed to build graph";
GetCaptureInfo().ordered_escaped_locals.clear();
for (const auto &traced_node : graph_->GetTracedNodes()) {
if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
break;
}
AddToEscaped(traced_node);
}
need_interpret_ = true;
GetCaptureInfo().captured_locals.order.clear();
GetCaptureInfo().captured_locals.values.clear();
GetCaptureInfo().captured_locals.inputs.clear();
return;
}
need_interpret_ = true;
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
return;
@ -390,6 +462,10 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
continue;
}
auto capturedLocals = info_.captured_locals.order;
if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
continue;
}
AObject *o = node->GetVobj();
auto out_py_obj = o->GetPyObject();
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
@ -404,7 +480,11 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
isAllNodesSupportOutput = false;
int new_break_point = node->bci();
auto curNode = node;
MS_EXCEPTION_IF_CHECK_FAIL(new_break_point != -1, "break point cannot be -1");
if (new_break_point == -1) {
// No node is unsupported output since no node in captured output.
isAllNodesSupportOutput = true;
break;
}
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);

View File

@ -60,9 +60,12 @@ class GraphAnalyzer {
virtual bool NeedInterpret() const { return need_interpret_; }
protected:
void AddToEscaped(ValueNode *value);
// UD analyze
void UseDefAnalyze();
void CollectInputs();
virtual void UseDefAnalyze();
std::vector<ValueNode *> GetAliveLocals(Graph *g);
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
virtual void CollectInputs();
bool need_interpret_;
Graph *graph_;
CapturedInfo info_;
@ -73,11 +76,8 @@ class GraphAnalyzer {
bool TryToCapture(AbstractNode *value);
bool AddToCaptured(ValueNode *value);
bool HandleCallableToGraph(AObject *f);
void AddToEscaped(ValueNode *value);
bool ProduceInterpretValue(ValueNode *v);
void CleanCapturedValue();
std::vector<ValueNode *> GetAliveLocals(Graph *g);
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
void ClearCapturedInfo();
};
@ -87,6 +87,10 @@ class MindGraphAnalyzer : public GraphAnalyzer {
void Analyze() override;
protected:
// UD analyze
void UseDefAnalyze() override;
void CollectInputs() override;
void UpdateCapturedOrder();
bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override;
GraphBuilderPtr graph_builder_ = nullptr;
};