Decouple registring passes from specifying argument/description

This patch changes the (not recommended) static registration API from:

 static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");

to:

 static PassRegistration<MyPass> reg;

And the explicit registration from:

  void registerPass("my-pass", "My Pass Description.",
                    [] { return createMyPass(); });

To:

  void registerPass([] { return createMyPass(); });

It is expected that Pass implementations overrides the getArgument() method
instead. This will ensure that pipeline description can be printed and parsed
back.

Differential Revision: https://reviews.llvm.org/D104421
This commit is contained in:
Mehdi Amini 2021-06-16 23:41:23 +00:00
parent fc4f457fcc
commit c8a3f561eb
8 changed files with 56 additions and 25 deletions

View File

@ -86,8 +86,7 @@ struct MyFunctionPass : public PassWrapper<MyFunctionPass,
/// Register this pass so that it can be built via from a textual pass pipeline.
/// (Pass registration is discussed more below)
void registerMyPass() {
PassRegistration<MyFunctionPass>(
"flag-name-to-invoke-pass-via-mlir-opt", "Pass description here");
PassRegistration<MyFunctionPass>();
}
```
@ -503,7 +502,15 @@ struct MyPass ... {
/// ensure that the options are initialized properly.
MyPass() = default;
MyPass(const MyPass& pass) {}
StringRef getArgument() const final {
// This is the argument used to refer to the pass in
// the textual format (on the commandline for example).
return "argument";
}
StringRef getDescription() const final {
// This is a brief description of the pass.
return "description";
}
/// Define the statistic to track during the execution of MyPass.
Statistic exampleStat{this, "exampleStat", "An example statistic"};
@ -562,21 +569,22 @@ example registration is shown below:
```c++
void registerMyPass() {
PassRegistration<MyPass>("argument", "description");
PassRegistration<MyPass>();
}
```
* `MyPass` is the name of the derived pass class.
* "argument" is the argument used to refer to the pass in the textual format.
* "description" is a brief description of the pass.
* The pass `getArgument()` method is used to get the identifier that will be
used to refer to the pass.
* The pass `getDescription()` method provides a short summary describing the
pass.
For passes that cannot be default-constructed, `PassRegistration` accepts an
optional third argument that takes a callback to create the pass:
optional argument that takes a callback to create the pass:
```c++
void registerMyPass() {
PassRegistration<MyParametricPass>(
"argument", "description",
[]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> p = std::make_unique<MyParametricPass>(/*options*/);
/*... non-trivial-logic to configure the pass ...*/;
@ -710,7 +718,7 @@ std::unique_ptr<Pass> foo::createMyPass() {
/// Register this pass.
void foo::registerMyPass() {
PassRegistration<MyPass>("my-pass", "My pass summary");
PassRegistration<MyPass>();
}
```

View File

@ -73,10 +73,14 @@ public:
/// register the Affine dialect but does not need to register Linalg.
virtual void getDependentDialects(DialectRegistry &registry) const {}
/// Returns the command line argument used when registering this pass. Return
/// Return the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const { return ""; }
/// Return the command line description used when registering this pass.
/// Return an empty string if one does not exist.
virtual StringRef getDescription() const { return ""; }
/// Returns the name of the operation that this pass operates on, or None if
/// this is a generic OperationPass.
Optional<StringRef> getOpName() const { return opName; }

View File

@ -125,20 +125,33 @@ void registerPassPipeline(
/// Register a specific dialect pass allocator function with the system,
/// typically used through the PassRegistration template.
/// Deprecated: please use the alternate version below.
void registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function);
/// Register a specific dialect pass allocator function with the system,
/// typically used through the PassRegistration template.
void registerPass(const PassAllocatorFunction &function);
/// PassRegistration provides a global initializer that registers a Pass
/// allocation routine for a concrete pass instance. The third argument is
/// allocation routine for a concrete pass instance. The argument is
/// optional and provides a callback to construct a pass that does not have
/// a default constructor.
///
/// Usage:
///
/// /// At namespace scope.
/// static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
/// static PassRegistration<MyPass> reg;
///
template <typename ConcretePass> struct PassRegistration {
PassRegistration(const PassAllocatorFunction &constructor) {
registerPass(constructor);
}
PassRegistration()
: PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
/// Constructor below are deprecated.
PassRegistration(StringRef arg, StringRef description,
const PassAllocatorFunction &constructor) {
registerPass(arg, description, constructor);

View File

@ -622,11 +622,6 @@ def PrintOpStats : Pass<"print-op-stats"> {
let constructor = "mlir::createPrintOpStatsPass()";
}
def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
let summary = "Print op graph per-Region";
let constructor = "mlir::createPrintOpGraphPass()";
}
def SCCP : Pass<"sccp"> {
let summary = "Sparse Conditional Constant Propagation";
let description = [{

View File

@ -122,6 +122,15 @@ void mlir::registerPass(StringRef arg, StringRef description,
}
}
void mlir::registerPass(const PassAllocatorFunction &function) {
std::unique_ptr<Pass> pass = function();
StringRef arg = pass->getArgument();
if (arg.empty())
llvm::report_fatal_error(
"Trying to register a pass that does not override `getArgument()`");
registerPass(arg, pass->getDescription(), function);
}
/// Returns the pass info for the specified pass argument or null if unknown.
const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
auto it = passRegistry->find(passArg);

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
// CHECK-LABEL: digraph "merge_blocks"
// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}

View File

@ -71,10 +71,10 @@ run(testParseFail)
def testInvalidNesting():
with Context():
try:
pm = PassManager.parse("func(print-op-graph)")
pm = PassManager.parse("func(view-op-graph)")
except ValueError as e:
# CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
# CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
# CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'.
log("ValueError exception:", e)
else:
log("Exception not produced")

View File

@ -56,6 +56,8 @@ public:
}
::llvm::StringRef getArgument() const override { return "{2}"; }
::llvm::StringRef getDescription() const override { return "{3}"; }
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
@ -74,7 +76,7 @@ public:
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{3}
{4}
}
protected:
@ -122,7 +124,8 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
dependentDialect);
}
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), dependentDialectRegistrations);
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
@ -154,8 +157,8 @@ const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
inline void register{0}Pass() {{
::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
return {3};
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
@ -175,7 +178,6 @@ static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
pass.getArgument(), pass.getSummary(),
pass.getConstructor());
}