[mlir] Change the internal representation of FrozenRewritePatternList to use shared_ptr

This will allow for caching pattern lists across multiple pass instances, such as when multithreading. This is an extremely important invariant for PDL patterns, which are compiled at runtime when the FrozenRewritePatternList is built.

Differential Revision: https://reviews.llvm.org/D93146
This commit is contained in:
River Riddle 2020-12-14 12:32:21 -08:00
parent 6f271e921b
commit 6af2c4ca9b
2 changed files with 32 additions and 15 deletions

View File

@ -18,34 +18,52 @@ class PDLByteCode;
/// This class represents a frozen set of patterns that can be processed by a
/// pattern applicator. This class is designed to enable caching pattern lists
/// such that they need not be continuously recomputed.
/// such that they need not be continuously recomputed. Note that all copies of
/// this class share the same compiled pattern list, allowing for a reduction in
/// the number of duplicated patterns that need to be created.
class FrozenRewritePatternList {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternList();
FrozenRewritePatternList(OwningRewritePatternList &&patterns);
FrozenRewritePatternList(FrozenRewritePatternList &&patterns);
FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default;
FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default;
FrozenRewritePatternList &
operator=(const FrozenRewritePatternList &patterns) = default;
FrozenRewritePatternList &
operator=(FrozenRewritePatternList &&patterns) = default;
~FrozenRewritePatternList();
/// Return the native patterns held by this list.
iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
getNativePatterns() const {
const NativePatternListT &nativePatterns = impl->nativePatterns;
return llvm::make_pointee_range(nativePatterns);
}
/// Return the compiled PDL bytecode held by this list. Returns null if
/// there are no PDL patterns within the list.
const detail::PDLByteCode *getPDLByteCode() const {
return pdlByteCode.get();
return impl->pdlByteCode.get();
}
private:
/// The set of.
std::vector<std::unique_ptr<RewritePattern>> nativePatterns;
/// The internal implementation of the frozen pattern list.
struct Impl {
/// The set of native C++ rewrite patterns.
NativePatternListT nativePatterns;
/// The bytecode containing the compiled PDL patterns.
std::unique_ptr<detail::PDLByteCode> pdlByteCode;
/// The bytecode containing the compiled PDL patterns.
std::unique_ptr<detail::PDLByteCode> pdlByteCode;
};
/// A pointer to the internal pattern list. This uses a shared_ptr to avoid
/// the need to compile the same pattern list multiple times. For example,
/// during multi-threaded pass execution, all copies of a pass can share the
/// same pattern list.
std::shared_ptr<Impl> impl;
};
} // end namespace mlir

View File

@ -50,12 +50,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
// FrozenRewritePatternList
//===----------------------------------------------------------------------===//
FrozenRewritePatternList::FrozenRewritePatternList()
: impl(std::make_shared<Impl>()) {}
FrozenRewritePatternList::FrozenRewritePatternList(
OwningRewritePatternList &&patterns)
: nativePatterns(std::move(patterns.getNativePatterns())) {
PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
: impl(std::make_shared<Impl>()) {
impl->nativePatterns = std::move(patterns.getNativePatterns());
// Generate the bytecode for the PDL patterns if any were provided.
PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
ModuleOp pdlModule = pdlPatterns.getModule();
if (!pdlModule)
return;
@ -64,14 +68,9 @@ FrozenRewritePatternList::FrozenRewritePatternList(
"failed to lower PDL pattern module to the PDL Interpreter");
// Generate the pdl bytecode.
pdlByteCode = std::make_unique<detail::PDLByteCode>(
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
}
FrozenRewritePatternList::FrozenRewritePatternList(
FrozenRewritePatternList &&patterns)
: nativePatterns(std::move(patterns.nativePatterns)),
pdlByteCode(std::move(patterns.pdlByteCode)) {}
FrozenRewritePatternList::~FrozenRewritePatternList() {}