fix ud for one stage
This commit is contained in:
parent
bc9f94dd1b
commit
8fce13d6bc
|
@ -330,10 +330,24 @@ py::object FuncGraphBuilder::TryToAddNode(const ValuePtr &callable_value, const
|
||||||
|
|
||||||
auto new_node = graph_->NewCNode(input_node_list);
|
auto new_node = graph_->NewCNode(input_node_list);
|
||||||
// Return the converted python object.
|
// Return the converted python object.
|
||||||
py::object output_py_obj = ConvertToPyObj(abs);
|
py::object output_py_obj;
|
||||||
if (output_py_obj.ptr() == nullptr) {
|
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||||
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
|
auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||||
return py::object();
|
auto fg = abs_func->func_graph();
|
||||||
|
if (fg == nullptr) {
|
||||||
|
return py::object();
|
||||||
|
}
|
||||||
|
auto obj = fg->python_obj();
|
||||||
|
if (obj == nullptr || !obj->isa<parse::PyObjectWrapper>()) {
|
||||||
|
return py::object();
|
||||||
|
}
|
||||||
|
output_py_obj = obj->cast_ptr<parse::PyObjectWrapper>()->obj();
|
||||||
|
} else {
|
||||||
|
output_py_obj = ConvertToPyObj(abs);
|
||||||
|
if (output_py_obj.ptr() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Convert abs " << abs->ToString() << " to python object failed.";
|
||||||
|
return py::object();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
new_node->set_abstract(abs);
|
new_node->set_abstract(abs);
|
||||||
|
|
|
@ -1570,31 +1570,6 @@ py::object MakeCodeFromCodeGen(const GraphBuilderPtr &builder, const GraphAnalyz
|
||||||
GraphAnalyzer::CapturedInfo info = analyzer->GetCaptureInfo();
|
GraphAnalyzer::CapturedInfo info = analyzer->GetCaptureInfo();
|
||||||
auto cg = CodeBreakGenerator::Creator(builder, graph->GetCodeObj());
|
auto cg = CodeBreakGenerator::Creator(builder, graph->GetCodeObj());
|
||||||
if (builder->trace_flag()) {
|
if (builder->trace_flag()) {
|
||||||
if (analyzer->NeedInterpret()) {
|
|
||||||
std::vector<ValueNode *> new_capture_locals;
|
|
||||||
const auto &traced_nodes = graph->GetTracedNodes();
|
|
||||||
auto break_bci = graph->GetStopTraceBci();
|
|
||||||
for (const auto &node : traced_nodes) {
|
|
||||||
if (node->bci() >= break_bci) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
new_capture_locals.push_back(node);
|
|
||||||
}
|
|
||||||
info.captured_locals.order = new_capture_locals;
|
|
||||||
// Collect inputs for break capture graph.
|
|
||||||
auto &values = info.captured_locals.values;
|
|
||||||
auto &inputs = info.captured_locals.inputs;
|
|
||||||
for (ValueNode *i : info.captured_locals.order) {
|
|
||||||
for (auto input : i->getInputs()) {
|
|
||||||
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
inputs.insert(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
info.captured_locals.order = graph->GetTracedNodes();
|
|
||||||
}
|
|
||||||
auto mind_builder = std::dynamic_pointer_cast<MindGraphBuilder>(builder);
|
auto mind_builder = std::dynamic_pointer_cast<MindGraphBuilder>(builder);
|
||||||
auto mind_fg_builder = mind_builder->FGBuilder();
|
auto mind_fg_builder = mind_builder->FGBuilder();
|
||||||
MS_EXCEPTION_IF_NULL(mind_fg_builder);
|
MS_EXCEPTION_IF_NULL(mind_fg_builder);
|
||||||
|
|
|
@ -308,10 +308,10 @@ void GraphAnalyzer::UseDefAnalyze() {
|
||||||
// UD analyze: alive nodes analysis
|
// UD analyze: alive nodes analysis
|
||||||
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||||
if (!aliveLocals.empty()) {
|
if (!aliveLocals.empty()) {
|
||||||
bool isStopAnalyze = false;
|
bool stop_analyze = false;
|
||||||
while (!isStopAnalyze) {
|
while (!stop_analyze) {
|
||||||
isStopAnalyze = AnalyzeAliveLocals(aliveLocals);
|
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||||
if (isStopAnalyze) {
|
if (stop_analyze) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
aliveLocals = GetAliveLocals(graph_);
|
aliveLocals = GetAliveLocals(graph_);
|
||||||
|
@ -320,6 +320,41 @@ void GraphAnalyzer::UseDefAnalyze() {
|
||||||
graph_->SetOldBreakBci(graph_->GetStopTraceBci());
|
graph_->SetOldBreakBci(graph_->GetStopTraceBci());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MindGraphAnalyzer::UpdateCapturedOrder() {
|
||||||
|
const auto &traced_nodes = graph_->GetTracedNodes();
|
||||||
|
auto stop_bci = graph_->GetStopTraceBci();
|
||||||
|
if (stop_bci == -1) {
|
||||||
|
GetCaptureInfo().captured_locals.order = traced_nodes;
|
||||||
|
} else {
|
||||||
|
GetCaptureInfo().captured_locals.order.clear();
|
||||||
|
for (const auto &traced_node : traced_nodes) {
|
||||||
|
if (traced_node->bci() >= stop_bci) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
GetCaptureInfo().captured_locals.order.push_back(traced_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const auto &captured_local_order = GetCaptureInfo().captured_locals.order;
|
||||||
|
std::set<ValueNode *> new_capture_local_values(captured_local_order.begin(), captured_local_order.end());
|
||||||
|
GetCaptureInfo().captured_locals.values = new_capture_local_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindGraphAnalyzer::UseDefAnalyze() {
|
||||||
|
// UD analyze: alive nodes analysis
|
||||||
|
std::vector<ValueNode *> aliveLocals = GetAliveLocals(graph_);
|
||||||
|
if (!aliveLocals.empty()) {
|
||||||
|
bool stop_analyze = false;
|
||||||
|
while (!stop_analyze) {
|
||||||
|
UpdateCapturedOrder();
|
||||||
|
// Add graph output according to leaf nodes.
|
||||||
|
stop_analyze = AnalyzeAliveLocals(aliveLocals);
|
||||||
|
if (!stop_analyze) {
|
||||||
|
aliveLocals = GetAliveLocals(graph_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void GraphAnalyzer::Analyze() {
|
void GraphAnalyzer::Analyze() {
|
||||||
const FrameStates &enter_frame = graph_->GetFrame(0);
|
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||||
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||||
|
@ -367,10 +402,47 @@ void GraphAnalyzer::Analyze() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MindGraphAnalyzer::CollectInputs() {
|
||||||
|
auto &values = GetCaptureInfo().captured_locals.values;
|
||||||
|
auto &inputs = GetCaptureInfo().captured_locals.inputs;
|
||||||
|
for (ValueNode *i : GetCaptureInfo().captured_locals.order) {
|
||||||
|
for (auto input : i->getInputs()) {
|
||||||
|
if (values.find(input) != values.end() || IsNonLocalValue(input)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
inputs.insert(input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void MindGraphAnalyzer::Analyze() {
|
void MindGraphAnalyzer::Analyze() {
|
||||||
|
auto origin_stop_bci = graph_->GetStopTraceBci();
|
||||||
UseDefAnalyze();
|
UseDefAnalyze();
|
||||||
CollectInputs();
|
CollectInputs();
|
||||||
|
|
||||||
|
const FrameStates &enter_frame = graph_->GetFrame(0);
|
||||||
|
GetCaptureInfo().escaped_locals.insert(enter_frame.GetLocals().begin(), enter_frame.GetLocals().end());
|
||||||
|
|
||||||
|
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||||
|
MS_EXCEPTION_IF_NULL(mind_graph_builder);
|
||||||
|
auto func_graph_builder = mind_graph_builder->FGBuilder();
|
||||||
|
if (func_graph_builder->graph() == nullptr) {
|
||||||
|
// Graph build failed, add all nodes to ordered_escaped_locals.
|
||||||
|
MS_LOG(DEBUG) << "Failed to build graph";
|
||||||
|
GetCaptureInfo().ordered_escaped_locals.clear();
|
||||||
|
for (const auto &traced_node : graph_->GetTracedNodes()) {
|
||||||
|
if (origin_stop_bci != -1 && traced_node->bci() >= origin_stop_bci) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
AddToEscaped(traced_node);
|
||||||
|
}
|
||||||
|
need_interpret_ = true;
|
||||||
|
GetCaptureInfo().captured_locals.order.clear();
|
||||||
|
GetCaptureInfo().captured_locals.values.clear();
|
||||||
|
GetCaptureInfo().captured_locals.inputs.clear();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
need_interpret_ = true;
|
need_interpret_ = true;
|
||||||
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
|
if (graph_->GetStopTraceBci() != -1 || !GetCaptureInfo().ordered_escaped_locals.empty()) {
|
||||||
return;
|
return;
|
||||||
|
@ -390,6 +462,10 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
|
||||||
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
|
MS_LOG(DEBUG) << "Skip non local value used as graph return.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
auto capturedLocals = info_.captured_locals.order;
|
||||||
|
if (std::find(capturedLocals.begin(), capturedLocals.end(), node) == capturedLocals.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
AObject *o = node->GetVobj();
|
AObject *o = node->GetVobj();
|
||||||
auto out_py_obj = o->GetPyObject();
|
auto out_py_obj = o->GetPyObject();
|
||||||
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
auto mind_graph_builder = std::static_pointer_cast<MindGraphBuilder>(graph_builder_);
|
||||||
|
@ -404,7 +480,11 @@ bool MindGraphAnalyzer::AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes)
|
||||||
isAllNodesSupportOutput = false;
|
isAllNodesSupportOutput = false;
|
||||||
int new_break_point = node->bci();
|
int new_break_point = node->bci();
|
||||||
auto curNode = node;
|
auto curNode = node;
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(new_break_point != -1, "break point cannot be -1");
|
if (new_break_point == -1) {
|
||||||
|
// No node is unsupported output since no node in captured output.
|
||||||
|
isAllNodesSupportOutput = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
|
MS_EXCEPTION_IF_NULL(curNode->GetGraph());
|
||||||
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
|
if (this->graph_->Config().GetBoolConfig(GraphJitConfig::kLogGraphBreak)) {
|
||||||
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
|
GRAPH_JIT_LOG_F("reset break point: %d", new_break_point);
|
||||||
|
|
|
@ -60,9 +60,12 @@ class GraphAnalyzer {
|
||||||
virtual bool NeedInterpret() const { return need_interpret_; }
|
virtual bool NeedInterpret() const { return need_interpret_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
void AddToEscaped(ValueNode *value);
|
||||||
// UD analyze
|
// UD analyze
|
||||||
void UseDefAnalyze();
|
virtual void UseDefAnalyze();
|
||||||
void CollectInputs();
|
std::vector<ValueNode *> GetAliveLocals(Graph *g);
|
||||||
|
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
|
||||||
|
virtual void CollectInputs();
|
||||||
bool need_interpret_;
|
bool need_interpret_;
|
||||||
Graph *graph_;
|
Graph *graph_;
|
||||||
CapturedInfo info_;
|
CapturedInfo info_;
|
||||||
|
@ -73,11 +76,8 @@ class GraphAnalyzer {
|
||||||
bool TryToCapture(AbstractNode *value);
|
bool TryToCapture(AbstractNode *value);
|
||||||
bool AddToCaptured(ValueNode *value);
|
bool AddToCaptured(ValueNode *value);
|
||||||
bool HandleCallableToGraph(AObject *f);
|
bool HandleCallableToGraph(AObject *f);
|
||||||
void AddToEscaped(ValueNode *value);
|
|
||||||
bool ProduceInterpretValue(ValueNode *v);
|
bool ProduceInterpretValue(ValueNode *v);
|
||||||
void CleanCapturedValue();
|
void CleanCapturedValue();
|
||||||
std::vector<ValueNode *> GetAliveLocals(Graph *g);
|
|
||||||
virtual bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes);
|
|
||||||
void ClearCapturedInfo();
|
void ClearCapturedInfo();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -87,6 +87,10 @@ class MindGraphAnalyzer : public GraphAnalyzer {
|
||||||
void Analyze() override;
|
void Analyze() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
// UD analyze
|
||||||
|
void UseDefAnalyze() override;
|
||||||
|
void CollectInputs() override;
|
||||||
|
void UpdateCapturedOrder();
|
||||||
bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override;
|
bool AnalyzeAliveLocals(std::vector<ValueNode *> aliveNodes) override;
|
||||||
GraphBuilderPtr graph_builder_ = nullptr;
|
GraphBuilderPtr graph_builder_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue