forked from OSchip/llvm-project
Add support for registering pass pipelines to the PassRegistry. This is done by providing a static registration facility PassPipelineRegistration that works similarly to PassRegistration except for it also takes a function that will add necessary passes to a provided PassManager.
void pipelineBuilder(PassManager &pm) { pm.addPass(new MyPass()); pm.addPass(new MyOtherPass()); } static PassPipelineRegistration Unused("unused", "Unused pass", pipelineBuilder); This is also useful for registering specializations of existing passes: Pass *createFooPass10() { return new FooPass(10); } static PassPipelineRegistration Unused("unused", "Unused pass", createFooPass10); PiperOrigin-RevId: 235996282
This commit is contained in:
parent
e31c23853b
commit
091ff3dc3f
|
@ -31,6 +31,10 @@
|
|||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
class PassManager;
|
||||
|
||||
/// A registry function that adds passes to the given pass manager.
|
||||
using PassRegistryFunction = std::function<void(PassManager &)>;
|
||||
|
||||
using PassAllocatorFunction = std::function<Pass *()>;
|
||||
|
||||
|
@ -43,22 +47,15 @@ struct alignas(8) PassID {
|
|||
}
|
||||
};
|
||||
|
||||
/// Structure to group information about a pass (argument to invoke via
|
||||
/// mlir-opt, description, pass allocator and unique ID).
|
||||
class PassInfo {
|
||||
/// Structure to group information about a passes and pass pipelines (argument
|
||||
/// to invoke via mlir-opt, description, pass pipeline builder).
|
||||
class PassRegistryEntry {
|
||||
public:
|
||||
/// PassInfo constructor should not be invoked directly, instead use
|
||||
/// PassRegistration or registerPass.
|
||||
PassInfo(StringRef arg, StringRef description, const PassID *passID,
|
||||
PassAllocatorFunction allocator)
|
||||
: arg(arg), description(description), allocator(allocator),
|
||||
passID(passID) {}
|
||||
|
||||
/// Returns an allocated instance of this pass.
|
||||
Pass *createPass() const {
|
||||
assert(allocator &&
|
||||
"Cannot call createPass on PassInfo without default allocator");
|
||||
return allocator();
|
||||
/// Adds this pass registry entry to the given pass manager.
|
||||
void addToPipeline(PassManager &pm) const {
|
||||
assert(builder &&
|
||||
"Cannot call addToPipeline on PassRegistryEntry without builder");
|
||||
builder(pm);
|
||||
}
|
||||
|
||||
/// Returns the command line option that may be passed to 'mlir-opt' that will
|
||||
|
@ -68,6 +65,11 @@ public:
|
|||
/// Returns a description for the pass, this never returns null.
|
||||
StringRef getPassDescription() const { return description; }
|
||||
|
||||
protected:
|
||||
PassRegistryEntry(StringRef arg, StringRef description,
|
||||
PassRegistryFunction builder)
|
||||
: arg(arg), description(description), builder(builder) {}
|
||||
|
||||
private:
|
||||
// The argument with which to invoke the pass via mlir-opt.
|
||||
StringRef arg;
|
||||
|
@ -75,20 +77,43 @@ private:
|
|||
// Description of the pass.
|
||||
StringRef description;
|
||||
|
||||
// Allocator to construct an instance of this pass.
|
||||
PassAllocatorFunction allocator;
|
||||
// Function to register this entry to a pass manager pipeline.
|
||||
PassRegistryFunction builder;
|
||||
};
|
||||
|
||||
/// A structure to represent the information of a registered pass pipeline.
|
||||
class PassPipelineInfo : public PassRegistryEntry {
|
||||
public:
|
||||
PassPipelineInfo(StringRef arg, StringRef description,
|
||||
PassRegistryFunction builder)
|
||||
: PassRegistryEntry(arg, description, builder) {}
|
||||
};
|
||||
|
||||
/// A structure to represent the information for a derived pass class.
|
||||
class PassInfo : public PassRegistryEntry {
|
||||
public:
|
||||
/// PassInfo constructor should not be invoked directly, instead use
|
||||
/// PassRegistration or registerPass.
|
||||
PassInfo(StringRef arg, StringRef description, const PassID *passID,
|
||||
PassAllocatorFunction allocator);
|
||||
|
||||
private:
|
||||
// Unique identifier for pass.
|
||||
const PassID *passID;
|
||||
};
|
||||
|
||||
/// Register a specific dialect creation function with the system, typically
|
||||
/// used through the PassRegistration template.
|
||||
/// Register a specific dialect pipeline registry function with the system,
|
||||
/// typically used through the PassPipelineRegistration template.
|
||||
void registerPassPipeline(StringRef arg, StringRef description,
|
||||
const PassRegistryFunction &function);
|
||||
|
||||
/// Register a specific dialect pass allocator function with the system,
|
||||
/// typically used through the PassRegistration template.
|
||||
void registerPass(StringRef arg, StringRef description, const PassID *passID,
|
||||
const PassAllocatorFunction &function);
|
||||
|
||||
/// PassRegistration provides a global initializer that registers a Pass
|
||||
/// allocation routine.
|
||||
/// allocation routine for a concrete pass instance.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
|
@ -97,12 +122,38 @@ void registerPass(StringRef arg, StringRef description, const PassID *passID,
|
|||
template <typename ConcretePass> struct PassRegistration {
|
||||
PassRegistration(StringRef arg, StringRef description) {
|
||||
registerPass(arg, description, PassID::getID<ConcretePass>(),
|
||||
[&]() { return new ConcretePass(); });
|
||||
[] { return new ConcretePass(); });
|
||||
}
|
||||
};
|
||||
|
||||
/// PassPipelineRegistration provides a global initializer that registers a Pass
|
||||
/// pipeline builder routine.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
/// // At namespace scope.
|
||||
/// void pipelineBuilder(PassManager &pm) {
|
||||
/// pm.addPass(new MyPass());
|
||||
/// pm.addPass(new MyOtherPass());
|
||||
/// }
|
||||
///
|
||||
/// static PassPipelineRegistration Unused("unused", "Unused pass",
|
||||
/// pipelineBuilder);
|
||||
struct PassPipelineRegistration {
|
||||
PassPipelineRegistration(StringRef arg, StringRef description,
|
||||
PassRegistryFunction builder) {
|
||||
registerPassPipeline(arg, description, builder);
|
||||
}
|
||||
|
||||
/// Constructor that accepts a pass allocator function instead of the standard
|
||||
/// registry function. This is useful for registering specializations of
|
||||
/// existing passes.
|
||||
PassPipelineRegistration(StringRef arg, StringRef description,
|
||||
PassAllocatorFunction allocator);
|
||||
};
|
||||
|
||||
/// Adds command line option for each registered pass.
|
||||
struct PassNameParser : public llvm::cl::parser<const PassInfo *> {
|
||||
struct PassNameParser : public llvm::cl::parser<const PassRegistryEntry *> {
|
||||
PassNameParser(llvm::cl::Option &opt);
|
||||
|
||||
void printOptionInfo(const llvm::cl::Option &O,
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
@ -161,20 +162,20 @@ static inline Error make_string_error(const llvm::Twine &message) {
|
|||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// Given a list of PassInfo coming from a higher level, creates the passes to
|
||||
// run as an owning vector and appends the extra required passes to lower to
|
||||
// LLVMIR. Currently, these extra passes are:
|
||||
// Given a list of PassRegistryEntry coming from a higher level, populates the
|
||||
// given pass manager and appends the default set of required passes to lower to
|
||||
// LLVMIR.
|
||||
// Currently, these passes are:
|
||||
// - constant folding
|
||||
// - CSE
|
||||
// - canonicalization
|
||||
// - affine lowering
|
||||
static void
|
||||
getDefaultPasses(PassManager &manager,
|
||||
const std::vector<const mlir::PassInfo *> &mlirPassInfoList) {
|
||||
static void getDefaultPasses(
|
||||
PassManager &manager,
|
||||
const std::vector<const mlir::PassRegistryEntry *> &mlirPassRegistryList) {
|
||||
// Run each of the passes that were selected.
|
||||
for (const auto *passInfo : mlirPassInfoList) {
|
||||
manager.addPass(passInfo->createPass());
|
||||
}
|
||||
for (const auto *passEntry : mlirPassRegistryList)
|
||||
passEntry->addToPipeline(manager);
|
||||
|
||||
// Append the extra passes for lowering to MLIR.
|
||||
manager.addPass(mlir::createConstantFoldPass());
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
|
||||
|
@ -26,13 +27,50 @@ using namespace mlir;
|
|||
static llvm::ManagedStatic<llvm::DenseMap<const PassID *, PassInfo>>
|
||||
passRegistry;
|
||||
|
||||
/// Static mapping of all of the registered pass pipelines.
|
||||
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
|
||||
passPipelineRegistry;
|
||||
|
||||
/// Utility to create a default registry function from a pass instance.
|
||||
static PassRegistryFunction
|
||||
buildDefaultRegistryFn(PassAllocatorFunction allocator) {
|
||||
return [=](PassManager &pm) { pm.addPass(allocator()); };
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassPipelineInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Constructor that accepts a pass allocator function instead of the standard
|
||||
/// registry function. This is useful for registering specializations of
|
||||
/// existing passes.
|
||||
PassPipelineRegistration::PassPipelineRegistration(
|
||||
StringRef arg, StringRef description, PassAllocatorFunction allocator) {
|
||||
registerPassPipeline(arg, description, buildDefaultRegistryFn(allocator));
|
||||
}
|
||||
|
||||
void mlir::registerPassPipeline(StringRef arg, StringRef description,
|
||||
const PassRegistryFunction &function) {
|
||||
PassPipelineInfo pipelineInfo(arg, description, function);
|
||||
bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
|
||||
assert(inserted && "Pass pipeline registered multiple times");
|
||||
(void)inserted;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
|
||||
PassAllocatorFunction allocator)
|
||||
: PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)),
|
||||
passID(passID) {}
|
||||
|
||||
void mlir::registerPass(StringRef arg, StringRef description,
|
||||
const PassID *passID,
|
||||
const PassAllocatorFunction &function) {
|
||||
bool inserted = passRegistry
|
||||
->insert(std::make_pair(
|
||||
passID, PassInfo(arg, description, passID, function)))
|
||||
.second;
|
||||
PassInfo passInfo(arg, description, passID, function);
|
||||
bool inserted = passRegistry->try_emplace(passID, passInfo).second;
|
||||
assert(inserted && "Pass registered multiple times");
|
||||
(void)inserted;
|
||||
}
|
||||
|
@ -45,12 +83,22 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
|
|||
return &it->getSecond();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassNameParser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PassNameParser::PassNameParser(llvm::cl::Option &opt)
|
||||
: llvm::cl::parser<const PassInfo *>(opt) {
|
||||
: llvm::cl::parser<const PassRegistryEntry *>(opt) {
|
||||
/// Add the pass entries.
|
||||
for (const auto &kv : *passRegistry) {
|
||||
addLiteralOption(kv.second.getPassArgument(), &kv.second,
|
||||
kv.second.getPassDescription());
|
||||
}
|
||||
/// Add the pass pipeline entries.
|
||||
for (const auto &kv : *passPipelineRegistry) {
|
||||
addLiteralOption(kv.second.getPassArgument(), &kv.second,
|
||||
kv.second.getPassDescription());
|
||||
}
|
||||
}
|
||||
|
||||
void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
|
||||
|
@ -62,5 +110,5 @@ void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
|
|||
return VT1->Name.compare(VT2->Name);
|
||||
});
|
||||
using llvm::cl::parser;
|
||||
parser<const PassInfo *>::printOptionInfo(O, GlobalWidth);
|
||||
parser<const PassRegistryEntry *>::printOptionInfo(O, GlobalWidth);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/TensorFlow/ControlFlowOps.h"
|
||||
|
@ -67,7 +68,7 @@ static cl::opt<bool>
|
|||
"expected-* lines on the corresponding line"),
|
||||
cl::init(false));
|
||||
|
||||
static std::vector<const mlir::PassInfo *> *passList;
|
||||
static std::vector<const mlir::PassRegistryEntry *> *passList;
|
||||
|
||||
enum OptResult { OptSuccess, OptFailure };
|
||||
|
||||
|
@ -128,8 +129,8 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
|
|||
// TODO(riverriddle) Make sure that the verifer is run after each pass when it
|
||||
// is no longer run by default within the PassManager.
|
||||
PassManager pm;
|
||||
for (const auto *passInfo : *passList)
|
||||
pm.addPass(passInfo->createPass());
|
||||
for (const auto *passEntry : *passList)
|
||||
passEntry->addToPipeline(pm);
|
||||
if (pm.run(module.get()))
|
||||
return OptFailure;
|
||||
|
||||
|
@ -364,8 +365,8 @@ int main(int argc, char **argv) {
|
|||
InitLLVM y(argc, argv);
|
||||
|
||||
// Parse pass names in main to ensure static initialization completed.
|
||||
llvm::cl::list<const mlir::PassInfo *, bool, mlir::PassNameParser> passList(
|
||||
"", llvm::cl::desc("Compiler passes to run"));
|
||||
llvm::cl::list<const mlir::PassRegistryEntry *, bool, mlir::PassNameParser>
|
||||
passList("", llvm::cl::desc("Compiler passes to run"));
|
||||
::passList = &passList;
|
||||
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
|
||||
|
||||
|
|
Loading…
Reference in New Issue