fix ud for one stage
This commit is contained in:
parent
bc9f94dd1b
commit
8fce13d6bc
|
@ -330,10 +330,24 @@ py::object FuncGraphBuilder::TryToAddNode(const ValuePtr &callable_value, const
|
|||
|
||||
auto new_node = graph_->NewCNode(input_node_list);
|
||||
// Return the converted python object.
|
||||
py::object output_py_obj = ConvertToPyObj(abs);
|
||||
if (output_py_obj.ptr() == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
|
||||
return py::object();
|
||||
py::object output_py_obj;
|
||||
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
auto fg = abs_func->func_graph();
|
||||
if (fg == nullptr) {
|
||||
return py::object();
|
||||
}
|
||||
auto obj = fg->python_obj();
|
||||
if (obj == nullptr || !obj->isa<parse::PyObjectWrapper>()) {
|
||||
return py::object();
|
||||
}
|
||||
output_py_obj = obj->cast_ptr<parse::PyObjectWrapper>()->obj();
|
||||
} else {
|
||||
output_py_obj = ConvertToPyObj(abs);
|
||||
if (output_py_obj.ptr() == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
|
||||
return py::object();
|
||||
}
|
||||
}
|
||||
|
||||
new_node->set_abstract(abs);
|
||||
|
|
|
@ -1570,31 +1570,6 @@ py::object MakeCodeFromCodeGen(const GraphBuilderPtr &builder, const GraphAnalyz
|
|||
GraphAnalyzer::CapturedInfo info = analyzer->GetCaptureInfo();
|
||||
auto cg = CodeBreakGenerator::Creator(builder, graph->GetCodeObj());
|
||||
if (builder->trace_flag()) {
|
||||
if (analyzer->NeedInterpret()) {
|
||||
std::vector<ValueNode *> new_capture_locals;
|
||||
const auto &traced_nodes = graph->GetTracedNodes();
|
||||
auto break_bci = graph->GetStopTraceBci();
|
||||
for (const auto &node : traced_nodes) {
|
||||
if (node->bci() >= break_bci) {
|
||||
break;
|
||||
}
|
||||
new_capture_locals.push_back(node);
|
||||
}
|
||||
info.captured_locals.order = new_capture_locals;
|
||||
// Collect inputs for break capture graph.
|
||||
auto &values = info.captured_locals.values;
|
||||
auto &inputs = info.captured_locals.inputs;
|
||||
for (ValueNode *i : info.captured_locals.order) {
|
||||
for (auto input : i->getInputs()) {
|
||||
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
|
||||
continue;
|
||||
}
|
||||
inputs.insert(input);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info.captured_locals.order = graph->GetTracedNodes();
|
||||
}
|
||||
auto mind_builder = std::dynamic_pointer_cast<MindGraphBuilder>(builder);
|
||||
auto mind_fg_builder = mind_builder->FGBuilder();
|
||||
MS_EXCEPTION_IF_NULL(mind_fg_builder);
|
||||
|
|
|
@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() {
|
|||
// UD analyze: alive nodes analysis
|
||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||
if (!aliveLocals.empty()) {
|
||||
bool isStopAnalyze = false;
|
||||
while (!isStopAnalyze) {
|
||||
isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (isStopAnalyze) {
|
||||
bool stop_analyze = false;
|
||||
while (!stop_analyze) {
|
||||
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (stop_analyze) {
|
||||
break;
|
||||
}
|
||||
aliveLocals = GetAliveLocals(graph_);
|
||||
|
@ -320,6 +320,41 @@ void GraphAnalyzer::UseDefAnalyze() {
|
|||
graph_->SetOldBreakBci(graph_->GetStopTraceBci());
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UpdateCapturedOrder() {
|
||||
const auto &traced_nodes = graph_->GetTracedNodes();
|
||||
auto stop_bci = graph_->GetStopTraceBci();
|
||||
if (stop_bci == -1) {
|
||||
GetCaptureInfo().captured_locals.order = traced_nodes;
|
||||
} else {
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
for (const auto &traced_node : traced_nodes) {
|
||||
if (traced_node->bci() >= stop_bci) {
|
||||
break;
|
||||
}
|
||||
GetCaptureInfo().captured_locals.order.push_back(traced_node);
|
||||
}
|
||||
}
|
||||
const auto &captured_local_order = GetCaptureInfo().captured_locals.order;
|
||||
std::set<ValueNode *> new_capture_local_values(captured_local_order.begin(), captured_local_order.end());
|
||||
GetCaptureInfo().captured_locals.values = new_capture_local_values;
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::UseDefAnalyze() {
|
||||
// UD analyze: alive nodes analysis
|
||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||
if (!aliveLocals.empty()) {
|
||||
bool stop_analyze = false;
|
||||
while (!stop_analyze) {
|
||||
UpdateCapturedOrder();
|
||||
// Add graph output according to leaf nodes.
|
||||
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||
if (!stop_analyze) {
|
||||
aliveLocals = GetAliveLocals(graph_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphAnalyzer::Analyze() {
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||
|
@ -367,10 +402,47 @@ void GraphAnalyzer::Analyze() {
|
|||
}
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::CollectInputs() {
|
||||
auto &values = GetCaptureInfo().captured_locals.values;
|
||||
auto &inputs = GetCaptureInfo().captured_locals.inputs;
|
||||
for (ValueNode *i : GetCaptureInfo().captured_locals.order) {
|
||||
for (auto input : i->getInputs()) {
|
||||
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
|
||||
continue;
|
||||
}
|
||||
inputs.insert(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MindGraphAnalyzer::Analyze() {
|
||||
auto origin_stop_bci = graph_->GetStopTraceBci();
|
||||
UseDefAnalyze();
|
||||
CollectInputs();
|
||||
|
||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||
if (func_graph_builder->graph() == nullptr) {
|
||||
// Graph build failed, add all nodes to ordered_escaped_locals.
|
||||
MS_LOG(DEBUG) << "Failed to build graph";
|
||||
GetCaptureInfo().ordered_escaped_locals.clear();
|
||||
for (const auto &traced_node : graph_->GetTracedNodes()) {
|
||||
if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
|
||||
break;
|
||||
}
|
||||
AddToEscaped(traced_node);
|
||||
}
|
||||
need_interpret_ = true;
|
||||
GetCaptureInfo().captured_locals.order.clear();
|
||||
GetCaptureInfo().captured_locals.values.clear();
|
||||
GetCaptureInfo().captured_locals.inputs.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
need_interpret_ = true;
|
||||
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
|
||||
return;
|
||||
|
@ -390,6 +462,10 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
|
|||
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
|
||||
continue;
|
||||
}
|
||||
auto capturedLocals = info_.captured_locals.order;
|
||||
if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
|
||||
continue;
|
||||
}
|
||||
AObject *o = node->GetVobj();
|
||||
auto out_py_obj = o->GetPyObject();
|
||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||
|
@ -404,7 +480,11 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
|
|||
isAllNodesSupportOutput = false;
|
||||
int new_break_point = node->bci();
|
||||
auto curNode = node;
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(new_break_point != -1, "break point cannot be -1");
|
||||
if (new_break_point == -1) {
|
||||
// No node is unsupported output since no node in captured output.
|
||||
isAllNodesSupportOutput = true;
|
||||
break;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
|
||||
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
|
||||
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
|
||||
|
|
|
@ -60,9 +60,12 @@ class GraphAnalyzer {
|
|||
virtual bool NeedInterpret() const { return need_interpret_; }
|
||||
|
||||
protected:
|
||||
void AddToEscaped(ValueNode *value);
|
||||
// UD analyze
|
||||
void UseDefAnalyze();
|
||||
void CollectInputs();
|
||||
virtual void UseDefAnalyze();
|
||||
std::vector<ValueNode *> GetAliveLocals(Graph *g);
|
||||
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
|
||||
virtual void CollectInputs();
|
||||
bool need_interpret_;
|
||||
Graph *graph_;
|
||||
CapturedInfo info_;
|
||||
|
@ -73,11 +76,8 @@ class GraphAnalyzer {
|
|||
bool TryToCapture(AbstractNode *value);
|
||||
bool AddToCaptured(ValueNode *value);
|
||||
bool HandleCallableToGraph(AObject *f);
|
||||
void AddToEscaped(ValueNode *value);
|
||||
bool ProduceInterpretValue(ValueNode *v);
|
||||
void CleanCapturedValue();
|
||||
std::vector<ValueNode *> GetAliveLocals(Graph *g);
|
||||
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
|
||||
void ClearCapturedInfo();
|
||||
};
|
||||
|
||||
|
@ -87,6 +87,10 @@ class MindGraphAnalyzer : public GraphAnalyzer {
|
|||
void Analyze() override;
|
||||
|
||||
protected:
|
||||
// UD analyze
|
||||
void UseDefAnalyze() override;
|
||||
void CollectInputs() override;
|
||||
void UpdateCapturedOrder();
|
||||
bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override;
|
||||
GraphBuilderPtr graph_builder_ = nullptr;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue