forked from OSchip/llvm-project
Decouple registring passes from specifying argument/description
This patch changes the (not recommended) static registration API from: static PassRegistration<MyPass> reg("my-pass", "My Pass Description."); to: static PassRegistration<MyPass> reg; And the explicit registration from: void registerPass("my-pass", "My Pass Description.", [] { return createMyPass(); }); To: void registerPass([] { return createMyPass(); }); It is expected that Pass implementations overrides the getArgument() method instead. This will ensure that pipeline description can be printed and parsed back. Differential Revision: https://reviews.llvm.org/D104421
This commit is contained in:
parent
fc4f457fcc
commit
c8a3f561eb
|
@ -86,8 +86,7 @@ struct MyFunctionPass : public PassWrapper<MyFunctionPass,
|
|||
/// Register this pass so that it can be built via from a textual pass pipeline.
|
||||
/// (Pass registration is discussed more below)
|
||||
void registerMyPass() {
|
||||
PassRegistration<MyFunctionPass>(
|
||||
"flag-name-to-invoke-pass-via-mlir-opt", "Pass description here");
|
||||
PassRegistration<MyFunctionPass>();
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -503,7 +502,15 @@ struct MyPass ... {
|
|||
/// ensure that the options are initialized properly.
|
||||
MyPass() = default;
|
||||
MyPass(const MyPass& pass) {}
|
||||
|
||||
StringRef getArgument() const final {
|
||||
// This is the argument used to refer to the pass in
|
||||
// the textual format (on the commandline for example).
|
||||
return "argument";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
// This is a brief description of the pass.
|
||||
return "description";
|
||||
}
|
||||
/// Define the statistic to track during the execution of MyPass.
|
||||
Statistic exampleStat{this, "exampleStat", "An example statistic"};
|
||||
|
||||
|
@ -562,21 +569,22 @@ example registration is shown below:
|
|||
|
||||
```c++
|
||||
void registerMyPass() {
|
||||
PassRegistration<MyPass>("argument", "description");
|
||||
PassRegistration<MyPass>();
|
||||
}
|
||||
```
|
||||
|
||||
* `MyPass` is the name of the derived pass class.
|
||||
* "argument" is the argument used to refer to the pass in the textual format.
|
||||
* "description" is a brief description of the pass.
|
||||
* The pass `getArgument()` method is used to get the identifier that will be
|
||||
used to refer to the pass.
|
||||
* The pass `getDescription()` method provides a short summary describing the
|
||||
pass.
|
||||
|
||||
For passes that cannot be default-constructed, `PassRegistration` accepts an
|
||||
optional third argument that takes a callback to create the pass:
|
||||
optional argument that takes a callback to create the pass:
|
||||
|
||||
```c++
|
||||
void registerMyPass() {
|
||||
PassRegistration<MyParametricPass>(
|
||||
"argument", "description",
|
||||
[]() -> std::unique_ptr<Pass> {
|
||||
std::unique_ptr<Pass> p = std::make_unique<MyParametricPass>(/*options*/);
|
||||
/*... non-trivial-logic to configure the pass ...*/;
|
||||
|
@ -710,7 +718,7 @@ std::unique_ptr<Pass> foo::createMyPass() {
|
|||
|
||||
/// Register this pass.
|
||||
void foo::registerMyPass() {
|
||||
PassRegistration<MyPass>("my-pass", "My pass summary");
|
||||
PassRegistration<MyPass>();
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -73,10 +73,14 @@ public:
|
|||
/// register the Affine dialect but does not need to register Linalg.
|
||||
virtual void getDependentDialects(DialectRegistry ®istry) const {}
|
||||
|
||||
/// Returns the command line argument used when registering this pass. Return
|
||||
/// Return the command line argument used when registering this pass. Return
|
||||
/// an empty string if one does not exist.
|
||||
virtual StringRef getArgument() const { return ""; }
|
||||
|
||||
/// Return the command line description used when registering this pass.
|
||||
/// Return an empty string if one does not exist.
|
||||
virtual StringRef getDescription() const { return ""; }
|
||||
|
||||
/// Returns the name of the operation that this pass operates on, or None if
|
||||
/// this is a generic OperationPass.
|
||||
Optional<StringRef> getOpName() const { return opName; }
|
||||
|
|
|
@ -125,20 +125,33 @@ void registerPassPipeline(
|
|||
|
||||
/// Register a specific dialect pass allocator function with the system,
|
||||
/// typically used through the PassRegistration template.
|
||||
/// Deprecated: please use the alternate version below.
|
||||
void registerPass(StringRef arg, StringRef description,
|
||||
const PassAllocatorFunction &function);
|
||||
|
||||
/// Register a specific dialect pass allocator function with the system,
|
||||
/// typically used through the PassRegistration template.
|
||||
void registerPass(const PassAllocatorFunction &function);
|
||||
|
||||
/// PassRegistration provides a global initializer that registers a Pass
|
||||
/// allocation routine for a concrete pass instance. The third argument is
|
||||
/// allocation routine for a concrete pass instance. The argument is
|
||||
/// optional and provides a callback to construct a pass that does not have
|
||||
/// a default constructor.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
/// /// At namespace scope.
|
||||
/// static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
|
||||
/// static PassRegistration<MyPass> reg;
|
||||
///
|
||||
template <typename ConcretePass> struct PassRegistration {
|
||||
PassRegistration(const PassAllocatorFunction &constructor) {
|
||||
registerPass(constructor);
|
||||
}
|
||||
PassRegistration()
|
||||
: PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
|
||||
|
||||
/// Constructor below are deprecated.
|
||||
|
||||
PassRegistration(StringRef arg, StringRef description,
|
||||
const PassAllocatorFunction &constructor) {
|
||||
registerPass(arg, description, constructor);
|
||||
|
|
|
@ -622,11 +622,6 @@ def PrintOpStats : Pass<"print-op-stats"> {
|
|||
let constructor = "mlir::createPrintOpStatsPass()";
|
||||
}
|
||||
|
||||
def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
|
||||
let summary = "Print op graph per-Region";
|
||||
let constructor = "mlir::createPrintOpGraphPass()";
|
||||
}
|
||||
|
||||
def SCCP : Pass<"sccp"> {
|
||||
let summary = "Sparse Conditional Constant Propagation";
|
||||
let description = [{
|
||||
|
|
|
@ -122,6 +122,15 @@ void mlir::registerPass(StringRef arg, StringRef description,
|
|||
}
|
||||
}
|
||||
|
||||
void mlir::registerPass(const PassAllocatorFunction &function) {
|
||||
std::unique_ptr<Pass> pass = function();
|
||||
StringRef arg = pass->getArgument();
|
||||
if (arg.empty())
|
||||
llvm::report_fatal_error(
|
||||
"Trying to register a pass that does not override `getArgument()`");
|
||||
registerPass(arg, pass->getDescription(), function);
|
||||
}
|
||||
|
||||
/// Returns the pass info for the specified pass argument or null if unknown.
|
||||
const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
|
||||
auto it = passRegistry->find(passArg);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: digraph "merge_blocks"
|
||||
// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
|
||||
|
|
|
@ -71,10 +71,10 @@ run(testParseFail)
|
|||
def testInvalidNesting():
|
||||
with Context():
|
||||
try:
|
||||
pm = PassManager.parse("func(print-op-graph)")
|
||||
pm = PassManager.parse("func(view-op-graph)")
|
||||
except ValueError as e:
|
||||
# CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
|
||||
# CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
|
||||
# CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'.
|
||||
log("ValueError exception:", e)
|
||||
else:
|
||||
log("Exception not produced")
|
||||
|
|
|
@ -56,6 +56,8 @@ public:
|
|||
}
|
||||
::llvm::StringRef getArgument() const override { return "{2}"; }
|
||||
|
||||
::llvm::StringRef getDescription() const override { return "{3}"; }
|
||||
|
||||
/// Returns the derived pass name.
|
||||
static constexpr ::llvm::StringLiteral getPassName() {
|
||||
return ::llvm::StringLiteral("{0}");
|
||||
|
@ -74,7 +76,7 @@ public:
|
|||
|
||||
/// Return the dialect that must be loaded in the context before this pass.
|
||||
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
||||
{3}
|
||||
{4}
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -122,7 +124,8 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
|||
dependentDialect);
|
||||
}
|
||||
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
|
||||
pass.getArgument(), dependentDialectRegistrations);
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
emitPassOptionDecls(pass, os);
|
||||
emitPassStatisticDecls(pass, os);
|
||||
os << "};\n";
|
||||
|
@ -154,8 +157,8 @@ const char *const passRegistrationCode = R"(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline void register{0}Pass() {{
|
||||
::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
|
||||
return {3};
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
|
||||
return {1};
|
||||
});
|
||||
}
|
||||
)";
|
||||
|
@ -175,7 +178,6 @@ static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
|
|||
os << "#ifdef GEN_PASS_REGISTRATION\n";
|
||||
for (const Pass &pass : passes) {
|
||||
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
pass.getConstructor());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue