[mlir] Resolve TODO and use the pass argument instead of the TypeID for registration

This simplifies various pieces of code that interact with the pass registry, e.g. this removes the need to register passes to get accurate pass pipelines descriptions when generating crash reproducers.

Differential Revision: https://reviews.llvm.org/D101880
This commit is contained in:
River Riddle 2021-06-02 12:06:32 -07:00
parent 8beaca8c14
commit fa51c5af5d
4 changed files with 34 additions and 30 deletions

View File

@ -56,13 +56,12 @@ public:
TypeID getTypeID() const { return passID; } TypeID getTypeID() const { return passID; }
/// Returns the pass info for the specified pass class or null if unknown. /// Returns the pass info for the specified pass class or null if unknown.
static const PassInfo *lookupPassInfo(TypeID passID); static const PassInfo *lookupPassInfo(StringRef passArg);
template <typename PassT> static const PassInfo *lookupPassInfo() {
return lookupPassInfo(TypeID::get<PassT>());
}
/// Returns the pass info for this pass. /// Returns the pass info for this pass, or null if unknown.
const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); } const PassInfo *lookupPassInfo() const {
return lookupPassInfo(getArgument());
}
/// Returns the derived pass name. /// Returns the derived pass name.
virtual StringRef getName() const = 0; virtual StringRef getName() const = 0;
@ -76,11 +75,7 @@ public:
/// Returns the command line argument used when registering this pass. Return /// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist. /// an empty string if one does not exist.
virtual StringRef getArgument() const { virtual StringRef getArgument() const { return ""; }
if (const PassInfo *passInfo = lookupPassInfo())
return passInfo->getPassArgument();
return "";
}
/// Returns the name of the operation that this pass operates on, or None if /// Returns the name of the operation that this pass operates on, or None if
/// this is a generic OperationPass. /// this is a generic OperationPass.

View File

@ -108,7 +108,7 @@ class PassInfo : public PassRegistryEntry {
public: public:
/// PassInfo constructor should not be invoked directly, instead use /// PassInfo constructor should not be invoked directly, instead use
/// PassRegistration or registerPass. /// PassRegistration or registerPass.
PassInfo(StringRef arg, StringRef description, TypeID passID, PassInfo(StringRef arg, StringRef description,
const PassAllocatorFunction &allocator); const PassAllocatorFunction &allocator);
}; };

View File

@ -19,7 +19,11 @@ using namespace mlir;
using namespace detail; using namespace detail;
/// Static mapping of all of the registered passes. /// Static mapping of all of the registered passes.
static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry; static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
/// Static mapping of all of the registered pass pipelines. /// Static mapping of all of the registered pass pipelines.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@ -94,7 +98,7 @@ void mlir::registerPassPipeline(
// PassInfo // PassInfo
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID, PassInfo::PassInfo(StringRef arg, StringRef description,
const PassAllocatorFunction &allocator) const PassAllocatorFunction &allocator)
: PassRegistryEntry( : PassRegistryEntry(
arg, description, buildDefaultRegistryFn(allocator), arg, description, buildDefaultRegistryFn(allocator),
@ -105,18 +109,23 @@ PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
void mlir::registerPass(StringRef arg, StringRef description, void mlir::registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function) { const PassAllocatorFunction &function) {
// TODO: We should use the 'arg' as the lookup key instead of the pass id. PassInfo passInfo(arg, description, function);
TypeID passID = function()->getTypeID(); passRegistry->try_emplace(arg, passInfo);
PassInfo passInfo(arg, description, passID, function);
passRegistry->try_emplace(passID, passInfo); // Verify that the registered pass has the same ID as any registered to this
// arg before it.
TypeID entryTypeID = function()->getTypeID();
auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
if (it->second != entryTypeID) {
llvm_unreachable("pass allocator creates a different pass than previously "
"registered");
}
} }
/// Returns the pass info for the specified pass class or null if unknown. /// Returns the pass info for the specified pass argument or null if unknown.
const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) { const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
auto it = passRegistry->find(passID); auto it = passRegistry->find(passArg);
if (it == passRegistry->end()) return it == passRegistry->end() ? nullptr : &it->second;
return nullptr;
return &it->getSecond();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -433,12 +442,8 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
} }
// If not, then this must be a specific pass name. // If not, then this must be a specific pass name.
for (auto &passIt : *passRegistry) { if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
if (passIt.second.getPassArgument() == element.name) { return success();
element.registryEntry = &passIt.second;
return success();
}
}
// Emit an error for the unknown pass. // Emit an error for the unknown pass.
auto *rawLoc = element.name.data(); auto *rawLoc = element.name.data();

View File

@ -16,9 +16,11 @@ namespace {
struct TestModulePass struct TestModulePass
: public PassWrapper<TestModulePass, OperationPass<ModuleOp>> { : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
void runOnOperation() final {} void runOnOperation() final {}
StringRef getArgument() const final { return "test-module-pass"; }
}; };
struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> { struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
void runOnFunction() final {} void runOnFunction() final {}
StringRef getArgument() const final { return "test-function-pass"; }
}; };
class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> { class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
public: public:
@ -41,6 +43,7 @@ public:
} }
void runOnFunction() final {} void runOnFunction() final {}
StringRef getArgument() const final { return "test-options-pass"; }
ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated, ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Example list option")}; llvm::cl::desc("Example list option")};
@ -56,6 +59,7 @@ public:
class TestCrashRecoveryPass class TestCrashRecoveryPass
: public PassWrapper<TestCrashRecoveryPass, OperationPass<>> { : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
void runOnOperation() final { abort(); } void runOnOperation() final { abort(); }
StringRef getArgument() const final { return "test-pass-crash"; }
}; };
/// A test pass that always fails to enable testing the failure recovery /// A test pass that always fails to enable testing the failure recovery