From 02d7b260c697adb20e9a12d36da49ee88eb714a5 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 20 Feb 2021 15:42:02 -0800 Subject: [PATCH] [mlir] Register the print-op-graph pass using ODS Move over to ODS & use pass options. --- mlir/include/mlir/Transforms/Passes.td | 18 ++++++++++++++++++ mlir/lib/Transforms/ViewOpGraph.cpp | 14 +++++++------- mlir/test/Transforms/print-op-graph.mlir | 12 ++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 mlir/test/Transforms/print-op-graph.mlir diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a03b439af339..925e7e79841d 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -730,4 +730,22 @@ def SymbolDCE : Pass<"symbol-dce"> { }]; let constructor = "mlir::createSymbolDCEPass()"; } + +def ViewOpGraphPass : Pass<"symbol-dce", "ModuleOp"> { + let summary = "Print graphviz view of module"; + let description = [{ + This pass prints a graphviz per block of a module. + + - Op are represented as nodes; + - Uses as edges; + }]; + let constructor = "mlir::createPrintOpGraphPass()"; + let options = [ + Option<"title", "title", "std::string", + /*default=*/"", "The prefix of the title of the graph">, + Option<"shortNames", "short-names", "bool", /*default=*/"false", + "Use short names"> + ]; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 97fe7b2d45bb..3d52d79b7ef7 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -104,10 +104,12 @@ namespace { // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public PrintOpBase { - explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false, - const Twine &title = "") - : os(os), title(title.str()), short_names(short_names) {} +class PrintOpPass : public ViewOpGraphPassBase { +public: + PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) { + this->shortNames = shortNames; + this->title = title.str(); + } std::string getOpName(Operation &op) { auto symbolAttr = @@ -133,7 +135,7 @@ struct PrintOpPass : public PrintOpBase { auto blockName = llvm::hasSingleElement(region) ? "" : ("__" + llvm::utostr(indexed_block.index())); - llvm::WriteGraph(os, &indexed_block.value(), short_names, + llvm::WriteGraph(os, &indexed_block.value(), shortNames, Twine(title) + opName + blockName); } } @@ -144,9 +146,7 @@ struct PrintOpPass : public PrintOpBase { private: raw_ostream &os; - std::string title; int unnamedOpCtr = 0; - bool short_names; }; } // namespace diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir new file mode 100644 index 000000000000..1c4548e931ca --- /dev/null +++ b/mlir/test/Transforms/print-op-graph.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt -allow-unregistered-dialect -print-op-graph %s -o %t 2>&1 | FileCheck %s + +// CHECK-LABEL: digraph "merge_blocks" +func @merge_blocks(%arg0: i32, %arg1 : i32) -> () { + %0:2 = "test.merge_blocks"() ({ + ^bb0: + "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> () + ^bb1(%arg3 : i32, %arg4 : i32): + "test.return"(%arg3, %arg4) : (i32, i32) -> () + }) : () -> (i32, i32) + "test.return"(%0#0, %0#1) : (i32, i32) -> () +}