From fa51c5af5d5de25a7824a939e90734ae5ca5448d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 2 Jun 2021 12:06:32 -0700 Subject: [PATCH] [mlir] Resolve TODO and use the pass argument instead of the TypeID for registration This simplifies various pieces of code that interact with the pass registry, e.g. this removes the need to register passes to get accurate pass pipelines descriptions when generating crash reproducers. Differential Revision: https://reviews.llvm.org/D101880 --- mlir/include/mlir/Pass/Pass.h | 17 ++++------- mlir/include/mlir/Pass/PassRegistry.h | 2 +- mlir/lib/Pass/PassRegistry.cpp | 41 +++++++++++++++----------- mlir/test/lib/Pass/TestPassManager.cpp | 4 +++ 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 42df2dc8bd08..67c695467706 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -56,13 +56,12 @@ public: TypeID getTypeID() const { return passID; } /// Returns the pass info for the specified pass class or null if unknown. - static const PassInfo *lookupPassInfo(TypeID passID); - template static const PassInfo *lookupPassInfo() { - return lookupPassInfo(TypeID::get()); - } + static const PassInfo *lookupPassInfo(StringRef passArg); - /// Returns the pass info for this pass. - const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); } + /// Returns the pass info for this pass, or null if unknown. + const PassInfo *lookupPassInfo() const { + return lookupPassInfo(getArgument()); + } /// Returns the derived pass name. virtual StringRef getName() const = 0; @@ -76,11 +75,7 @@ public: /// Returns the command line argument used when registering this pass. Return /// an empty string if one does not exist. - virtual StringRef getArgument() const { - if (const PassInfo *passInfo = lookupPassInfo()) - return passInfo->getPassArgument(); - return ""; - } + virtual StringRef getArgument() const { return ""; } /// Returns the name of the operation that this pass operates on, or None if /// this is a generic OperationPass. diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index 8def0f31ad15..d03aaf8dfd25 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -108,7 +108,7 @@ class PassInfo : public PassRegistryEntry { public: /// PassInfo constructor should not be invoked directly, instead use /// PassRegistration or registerPass. - PassInfo(StringRef arg, StringRef description, TypeID passID, + PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator); }; diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index e53113eb968e..2c690a2659ac 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -19,7 +19,11 @@ using namespace mlir; using namespace detail; /// Static mapping of all of the registered passes. -static llvm::ManagedStatic> passRegistry; +static llvm::ManagedStatic> passRegistry; + +/// A mapping of the above pass registry entries to the corresponding TypeID +/// of the pass that they generate. +static llvm::ManagedStatic> passRegistryTypeIDs; /// Static mapping of all of the registered pass pipelines. static llvm::ManagedStatic> @@ -94,7 +98,7 @@ void mlir::registerPassPipeline( // PassInfo //===----------------------------------------------------------------------===// -PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID, +PassInfo::PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator) : PassRegistryEntry( arg, description, buildDefaultRegistryFn(allocator), @@ -105,18 +109,23 @@ PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID, void mlir::registerPass(StringRef arg, StringRef description, const PassAllocatorFunction &function) { - // TODO: We should use the 'arg' as the lookup key instead of the pass id. - TypeID passID = function()->getTypeID(); - PassInfo passInfo(arg, description, passID, function); - passRegistry->try_emplace(passID, passInfo); + PassInfo passInfo(arg, description, function); + passRegistry->try_emplace(arg, passInfo); + + // Verify that the registered pass has the same ID as any registered to this + // arg before it. + TypeID entryTypeID = function()->getTypeID(); + auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first; + if (it->second != entryTypeID) { + llvm_unreachable("pass allocator creates a different pass than previously " + "registered"); + } } -/// Returns the pass info for the specified pass class or null if unknown. -const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) { - auto it = passRegistry->find(passID); - if (it == passRegistry->end()) - return nullptr; - return &it->getSecond(); +/// Returns the pass info for the specified pass argument or null if unknown. +const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) { + auto it = passRegistry->find(passArg); + return it == passRegistry->end() ? nullptr : &it->second; } //===----------------------------------------------------------------------===// @@ -433,12 +442,8 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element, } // If not, then this must be a specific pass name. - for (auto &passIt : *passRegistry) { - if (passIt.second.getPassArgument() == element.name) { - element.registryEntry = &passIt.second; - return success(); - } - } + if ((element.registryEntry = Pass::lookupPassInfo(element.name))) + return success(); // Emit an error for the unknown pass. auto *rawLoc = element.name.data(); diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 937a5c2317c2..6e5a5b9de8ba 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -16,9 +16,11 @@ namespace { struct TestModulePass : public PassWrapper> { void runOnOperation() final {} + StringRef getArgument() const final { return "test-module-pass"; } }; struct TestFunctionPass : public PassWrapper { void runOnFunction() final {} + StringRef getArgument() const final { return "test-function-pass"; } }; class TestOptionsPass : public PassWrapper { public: @@ -41,6 +43,7 @@ public: } void runOnFunction() final {} + StringRef getArgument() const final { return "test-options-pass"; } ListOption listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Example list option")}; @@ -56,6 +59,7 @@ public: class TestCrashRecoveryPass : public PassWrapper> { void runOnOperation() final { abort(); } + StringRef getArgument() const final { return "test-pass-crash"; } }; /// A test pass that always fails to enable testing the failure recovery