diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index 20f64cc274c1..588129e3f6ee 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -31,6 +31,10 @@ namespace mlir { class Pass; +class PassManager; + +/// A registry function that adds passes to the given pass manager. +using PassRegistryFunction = std::function; using PassAllocatorFunction = std::function; @@ -43,22 +47,15 @@ struct alignas(8) PassID { } }; -/// Structure to group information about a pass (argument to invoke via -/// mlir-opt, description, pass allocator and unique ID). -class PassInfo { +/// Structure to group information about a passes and pass pipelines (argument +/// to invoke via mlir-opt, description, pass pipeline builder). +class PassRegistryEntry { public: - /// PassInfo constructor should not be invoked directly, instead use - /// PassRegistration or registerPass. - PassInfo(StringRef arg, StringRef description, const PassID *passID, - PassAllocatorFunction allocator) - : arg(arg), description(description), allocator(allocator), - passID(passID) {} - - /// Returns an allocated instance of this pass. - Pass *createPass() const { - assert(allocator && - "Cannot call createPass on PassInfo without default allocator"); - return allocator(); + /// Adds this pass registry entry to the given pass manager. + void addToPipeline(PassManager &pm) const { + assert(builder && + "Cannot call addToPipeline on PassRegistryEntry without builder"); + builder(pm); } /// Returns the command line option that may be passed to 'mlir-opt' that will @@ -68,6 +65,11 @@ public: /// Returns a description for the pass, this never returns null. StringRef getPassDescription() const { return description; } +protected: + PassRegistryEntry(StringRef arg, StringRef description, + PassRegistryFunction builder) + : arg(arg), description(description), builder(builder) {} + private: // The argument with which to invoke the pass via mlir-opt. StringRef arg; @@ -75,20 +77,43 @@ private: // Description of the pass. StringRef description; - // Allocator to construct an instance of this pass. - PassAllocatorFunction allocator; + // Function to register this entry to a pass manager pipeline. + PassRegistryFunction builder; +}; +/// A structure to represent the information of a registered pass pipeline. +class PassPipelineInfo : public PassRegistryEntry { +public: + PassPipelineInfo(StringRef arg, StringRef description, + PassRegistryFunction builder) + : PassRegistryEntry(arg, description, builder) {} +}; + +/// A structure to represent the information for a derived pass class. +class PassInfo : public PassRegistryEntry { +public: + /// PassInfo constructor should not be invoked directly, instead use + /// PassRegistration or registerPass. + PassInfo(StringRef arg, StringRef description, const PassID *passID, + PassAllocatorFunction allocator); + +private: // Unique identifier for pass. const PassID *passID; }; -/// Register a specific dialect creation function with the system, typically -/// used through the PassRegistration template. +/// Register a specific dialect pipeline registry function with the system, +/// typically used through the PassPipelineRegistration template. +void registerPassPipeline(StringRef arg, StringRef description, + const PassRegistryFunction &function); + +/// Register a specific dialect pass allocator function with the system, +/// typically used through the PassRegistration template. void registerPass(StringRef arg, StringRef description, const PassID *passID, const PassAllocatorFunction &function); /// PassRegistration provides a global initializer that registers a Pass -/// allocation routine. +/// allocation routine for a concrete pass instance. /// /// Usage: /// @@ -97,12 +122,38 @@ void registerPass(StringRef arg, StringRef description, const PassID *passID, template struct PassRegistration { PassRegistration(StringRef arg, StringRef description) { registerPass(arg, description, PassID::getID(), - [&]() { return new ConcretePass(); }); + [] { return new ConcretePass(); }); } }; +/// PassPipelineRegistration provides a global initializer that registers a Pass +/// pipeline builder routine. +/// +/// Usage: +/// +/// // At namespace scope. +/// void pipelineBuilder(PassManager &pm) { +/// pm.addPass(new MyPass()); +/// pm.addPass(new MyOtherPass()); +/// } +/// +/// static PassPipelineRegistration Unused("unused", "Unused pass", +/// pipelineBuilder); +struct PassPipelineRegistration { + PassPipelineRegistration(StringRef arg, StringRef description, + PassRegistryFunction builder) { + registerPassPipeline(arg, description, builder); + } + + /// Constructor that accepts a pass allocator function instead of the standard + /// registry function. This is useful for registering specializations of + /// existing passes. + PassPipelineRegistration(StringRef arg, StringRef description, + PassAllocatorFunction allocator); +}; + /// Adds command line option for each registered pass. -struct PassNameParser : public llvm::cl::parser { +struct PassNameParser : public llvm::cl::parser { PassNameParser(llvm::cl::Option &opt); void printOptionInfo(const llvm::cl::Option &O, diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 1a3dd6ffff07..f0835da77d48 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -22,6 +22,7 @@ #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" @@ -161,20 +162,20 @@ static inline Error make_string_error(const llvm::Twine &message) { llvm::inconvertibleErrorCode()); } -// Given a list of PassInfo coming from a higher level, creates the passes to -// run as an owning vector and appends the extra required passes to lower to -// LLVMIR. Currently, these extra passes are: +// Given a list of PassRegistryEntry coming from a higher level, populates the +// given pass manager and appends the default set of required passes to lower to +// LLVMIR. +// Currently, these passes are: // - constant folding // - CSE // - canonicalization // - affine lowering -static void -getDefaultPasses(PassManager &manager, - const std::vector &mlirPassInfoList) { +static void getDefaultPasses( + PassManager &manager, + const std::vector &mlirPassRegistryList) { // Run each of the passes that were selected. - for (const auto *passInfo : mlirPassInfoList) { - manager.addPass(passInfo->createPass()); - } + for (const auto *passEntry : mlirPassRegistryList) + passEntry->addToPipeline(manager); // Append the extra passes for lowering to MLIR. manager.addPass(mlir::createConstantFoldPass()); diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index e90fb2217a22..b0927f4f550e 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -17,6 +17,7 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/ManagedStatic.h" @@ -26,13 +27,50 @@ using namespace mlir; static llvm::ManagedStatic> passRegistry; +/// Static mapping of all of the registered pass pipelines. +static llvm::ManagedStatic> + passPipelineRegistry; + +/// Utility to create a default registry function from a pass instance. +static PassRegistryFunction +buildDefaultRegistryFn(PassAllocatorFunction allocator) { + return [=](PassManager &pm) { pm.addPass(allocator()); }; +} + +//===----------------------------------------------------------------------===// +// PassPipelineInfo +//===----------------------------------------------------------------------===// + +/// Constructor that accepts a pass allocator function instead of the standard +/// registry function. This is useful for registering specializations of +/// existing passes. +PassPipelineRegistration::PassPipelineRegistration( + StringRef arg, StringRef description, PassAllocatorFunction allocator) { + registerPassPipeline(arg, description, buildDefaultRegistryFn(allocator)); +} + +void mlir::registerPassPipeline(StringRef arg, StringRef description, + const PassRegistryFunction &function) { + PassPipelineInfo pipelineInfo(arg, description, function); + bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; + assert(inserted && "Pass pipeline registered multiple times"); + (void)inserted; +} + +//===----------------------------------------------------------------------===// +// PassInfo +//===----------------------------------------------------------------------===// + +PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID, + PassAllocatorFunction allocator) + : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)), + passID(passID) {} + void mlir::registerPass(StringRef arg, StringRef description, const PassID *passID, const PassAllocatorFunction &function) { - bool inserted = passRegistry - ->insert(std::make_pair( - passID, PassInfo(arg, description, passID, function))) - .second; + PassInfo passInfo(arg, description, passID, function); + bool inserted = passRegistry->try_emplace(passID, passInfo).second; assert(inserted && "Pass registered multiple times"); (void)inserted; } @@ -45,12 +83,22 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) { return &it->getSecond(); } +//===----------------------------------------------------------------------===// +// PassNameParser +//===----------------------------------------------------------------------===// + PassNameParser::PassNameParser(llvm::cl::Option &opt) - : llvm::cl::parser(opt) { + : llvm::cl::parser(opt) { + /// Add the pass entries. for (const auto &kv : *passRegistry) { addLiteralOption(kv.second.getPassArgument(), &kv.second, kv.second.getPassDescription()); } + /// Add the pass pipeline entries. + for (const auto &kv : *passPipelineRegistry) { + addLiteralOption(kv.second.getPassArgument(), &kv.second, + kv.second.getPassDescription()); + } } void PassNameParser::printOptionInfo(const llvm::cl::Option &O, @@ -62,5 +110,5 @@ void PassNameParser::printOptionInfo(const llvm::cl::Option &O, return VT1->Name.compare(VT2->Name); }); using llvm::cl::parser; - parser::printOptionInfo(O, GlobalWidth); + parser::printOptionInfo(O, GlobalWidth); } diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 4a2b4e7489ff..96c85b237198 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "mlir/TensorFlow/ControlFlowOps.h" @@ -67,7 +68,7 @@ static cl::opt "expected-* lines on the corresponding line"), cl::init(false)); -static std::vector *passList; +static std::vector *passList; enum OptResult { OptSuccess, OptFailure }; @@ -128,8 +129,8 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { // TODO(riverriddle) Make sure that the verifer is run after each pass when it // is no longer run by default within the PassManager. PassManager pm; - for (const auto *passInfo : *passList) - pm.addPass(passInfo->createPass()); + for (const auto *passEntry : *passList) + passEntry->addToPipeline(pm); if (pm.run(module.get())) return OptFailure; @@ -364,8 +365,8 @@ int main(int argc, char **argv) { InitLLVM y(argc, argv); // Parse pass names in main to ensure static initialization completed. - llvm::cl::list passList( - "", llvm::cl::desc("Compiler passes to run")); + llvm::cl::list + passList("", llvm::cl::desc("Compiler passes to run")); ::passList = &passList; cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");