[mlir][Pattern] Add better support for using interfaces/traits to match root operations in rewrite patterns

To match an interface or trait, users currently have to use the `MatchAny` tag. This tag can be quite problematic for compile time for things like the canonicalizer, as the `MatchAny` patterns may get applied to  *every* operation. This revision adds better support by bucketing interface/trait patterns based on which registered operations have them registered. This means that moving forward we will only attempt to match these patterns to operations that have this interface registered. Two simplify defining patterns that match traits and interfaces, two new utility classes have been added: OpTraitRewritePattern and OpInterfaceRewritePattern.

Differential Revision: https://reviews.llvm.org/D98986
This commit is contained in:
River Riddle 2021-03-23 13:44:14 -07:00
parent 782c534117
commit 76f3c2f3f3
33 changed files with 462 additions and 254 deletions

View File

@ -697,7 +697,7 @@ static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); }
template <typename FltOp, typename CpxOp>
struct UndoComplexPattern : public mlir::RewritePattern {
UndoComplexPattern(mlir::MLIRContext *ctx)
: mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {}
: mlir::RewritePattern("fir.insert_value", 2, ctx) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,

View File

@ -30,12 +30,12 @@ namespace linalg {
// or in an externally linked library.
// This is a generic entry point for all LinalgOp, except for CopyOp and
// IndexedGenericOp, for which omre specialized patterns are provided.
class LinalgOpToLibraryCallRewrite : public RewritePattern {
class LinalgOpToLibraryCallRewrite
: public OpInterfaceRewritePattern<LinalgOp> {
public:
LinalgOpToLibraryCallRewrite()
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override;
};

View File

@ -60,7 +60,8 @@ void enqueue(RewritePatternSet &patternList, OptionsType options,
if (!opName.empty())
patternList.add<PatternType>(opName, patternList.getContext(), options, m);
else
patternList.add<PatternType>(m.addOpFilter<OpType>(), options);
patternList.add<PatternType>(patternList.getContext(),
m.addOpFilter<OpType>(), options);
}
/// Promotion transformation enqueues a particular stage-1 pattern for

View File

@ -452,7 +452,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgBaseTilingPattern(
LinalgTilingOptions options,
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
@ -644,7 +644,8 @@ struct LinalgVectorizationOptions {};
struct LinalgBaseVectorizationPattern : public RewritePattern {
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgBaseVectorizationPattern(LinalgTransformationFilter filter,
LinalgBaseVectorizationPattern(MLIRContext *context,
LinalgTransformationFilter filter,
PatternBenefit benefit = 1);
/// Name-based constructor with an optional `filter`.
LinalgBaseVectorizationPattern(
@ -663,10 +664,10 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
/// These constructors are available to anyone.
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgVectorizationPattern(
LinalgTransformationFilter filter,
MLIRContext *context, LinalgTransformationFilter filter,
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
PatternBenefit benefit = 1)
: LinalgBaseVectorizationPattern(filter, benefit) {}
: LinalgBaseVectorizationPattern(context, filter, benefit) {}
/// Name-based constructor with an optional `filter`.
LinalgVectorizationPattern(
StringRef opName, MLIRContext *context,
@ -702,8 +703,8 @@ template <typename OpType, typename = std::enable_if_t<
void insertVectorizationPatternImpl(RewritePatternSet &patternList,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
patternList.add<linalg::LinalgVectorizationPattern>(f.addOpFilter<OpType>(),
options);
patternList.add<linalg::LinalgVectorizationPattern>(
patternList.getContext(), f.addOpFilter<OpType>(), options);
}
/// Variadic helper function to insert vectorization patterns for C++ ops.
@ -737,7 +738,7 @@ struct LinalgLoweringPattern : public RewritePattern {
MLIRContext *context, LinalgLoweringType loweringType,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
: RewritePattern(OpTy::getOperationName(), benefit, context),
filter(filter), loweringType(loweringType),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

View File

@ -123,7 +123,8 @@ struct UnrollVectorOptions {
struct UnrollVectorPattern : public RewritePattern {
using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
options(options) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (options.filterConstraint && failed(options.filterConstraint(op)))
@ -216,7 +217,7 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
filter(filter) {}
/// Performs the rewrite.

View File

@ -1516,6 +1516,13 @@ public:
#endif
return false;
}
/// Provide `classof` support for other OpBase derived classes, such as
/// Interfaces.
template <typename T>
static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
classof(const T *op) {
return classof(const_cast<T *>(op)->getOperation());
}
/// Expose the type we are instantiated on to template machinery that may want
/// to introspect traits on this operation.

View File

@ -142,12 +142,20 @@ public:
return interfaceMap.lookup<T>();
}
/// Returns true if this operation has the given interface registered to it.
bool hasInterface(TypeID interfaceID) const {
return interfaceMap.contains(interfaceID);
}
/// Returns true if the operation has a particular trait.
template <template <typename T> class Trait>
bool hasTrait() const {
return hasTraitFn(TypeID::get<Trait>());
}
/// Returns true if the operation has a particular trait.
bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
/// Look up the specified operation in the specified MLIRContext and return a
/// pointer to it if present. Otherwise, return a null pointer.
static const AbstractOperation *lookup(StringRef opName,

View File

@ -68,6 +68,19 @@ private:
/// used to interface with the metadata of a pattern, such as the benefit or
/// root operation.
class Pattern {
/// This enum represents the kind of value used to select the root operations
/// that match this pattern.
enum class RootKind {
/// The pattern root matches "any" operation.
Any,
/// The pattern root is matched using a concrete operation name.
OperationName,
/// The pattern root is matched using an interface ID.
InterfaceID,
/// The patter root is matched using a trait ID.
TraitID
};
public:
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
@ -75,7 +88,29 @@ public:
/// Return the root node that this pattern matches. Patterns that can match
/// multiple root types return None.
Optional<OperationName> getRootKind() const { return rootKind; }
Optional<OperationName> getRootKind() const {
if (rootKind == RootKind::OperationName)
return OperationName::getFromOpaquePointer(rootValue);
return llvm::None;
}
/// Return the interface ID used to match the root operation of this pattern.
/// If the pattern does not use an interface ID for deciding the root match,
/// this returns None.
Optional<TypeID> getRootInterfaceID() const {
if (rootKind == RootKind::InterfaceID)
return TypeID::getFromOpaquePointer(rootValue);
return llvm::None;
}
/// Return the trait ID used to match the root operation of this pattern.
/// If the pattern does not use a trait ID for deciding the root match, this
/// returns None.
Optional<TypeID> getRootTraitID() const {
if (rootKind == RootKind::TraitID)
return TypeID::getFromOpaquePointer(rootValue);
return llvm::None;
}
/// Return the benefit (the inverse of "cost") of matching this pattern. The
/// benefit of a Pattern is always static - rewrites that may have dynamic
@ -88,56 +123,85 @@ public:
/// i.e. this pattern may generate IR that also matches this pattern, but is
/// known to bound the recursion. This signals to a rewrite driver that it is
/// safe to apply this pattern recursively to generated IR.
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
bool hasBoundedRewriteRecursion() const {
return contextAndHasBoundedRecursion.getInt();
}
/// Return the MLIRContext used to create this pattern.
MLIRContext *getContext() const {
return contextAndHasBoundedRecursion.getPointer();
}
protected:
/// This class acts as a special tag that makes the desire to match "any"
/// operation type explicit. This helps to avoid unnecessary usages of this
/// feature, and ensures that the user is making a conscious decision.
struct MatchAnyOpTypeTag {};
/// This class acts as a special tag that makes the desire to match any
/// operation that implements a given interface explicit. This helps to avoid
/// unnecessary usages of this feature, and ensures that the user is making a
/// conscious decision.
struct MatchInterfaceOpTypeTag {};
/// This class acts as a special tag that makes the desire to match any
/// operation that implements a given trait explicit. This helps to avoid
/// unnecessary usages of this feature, and ensures that the user is making a
/// conscious decision.
struct MatchTraitOpTypeTag {};
/// Construct a pattern with a certain benefit that matches the operation
/// with the given root name.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
/// Construct a pattern with a certain benefit that matches any operation
/// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
/// Construct a pattern with a certain benefit that matches the operation with
/// the given root name. `generatedNames` contains the names of operations
/// that may be generated during a successful rewrite.
Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
ArrayRef<StringRef> generatedNames = {});
/// Construct a pattern that may match any operation type. `generatedNames`
/// contains the names of operations that may be generated during a successful
/// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag);
Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
ArrayRef<StringRef> generatedNames = {});
/// Construct a pattern that may match any operation that implements the
/// interface defined by the provided `interfaceID`. `generatedNames` contains
/// the names of operations that may be generated during a successful rewrite.
/// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
/// interface" behavior is what the user actually desired,
/// `MatchInterfaceOpTypeTag()` should always be supplied here.
Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
PatternBenefit benefit, MLIRContext *context,
ArrayRef<StringRef> generatedNames = {});
/// Construct a pattern that may match any operation that implements the
/// trait defined by the provided `traitID`. `generatedNames` contains the
/// names of operations that may be generated during a successful rewrite.
/// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
/// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
/// always be supplied here.
Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
/// Set the flag detailing if this pattern has bounded rewrite recursion or
/// not.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
hasBoundedRecursion = hasBoundedRecursionArg;
contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
}
private:
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
Pattern(const void *rootValue, RootKind rootKind,
ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context);
/// The root operation of the pattern. If the pattern matches a specific
/// operation, this contains the name of that operation. Contains None
/// otherwise.
Optional<OperationName> rootKind;
/// The value used to match the root operation of the pattern.
const void *rootValue;
RootKind rootKind;
/// The expected benefit of matching this pattern.
const PatternBenefit benefit;
/// A boolean flag of whether this pattern has bounded recursion or not.
bool hasBoundedRecursion = false;
/// The context this pattern was created from, and a boolean flag indicating
/// whether this pattern has bounded recursion or not.
llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
};
//===----------------------------------------------------------------------===//
@ -188,15 +252,13 @@ protected:
virtual void anchor();
};
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
namespace detail {
/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
/// allows for matching and rewriting against an instance of a derived operation
/// class or Interface.
template <typename SourceOp>
struct OpRewritePattern : public RewritePattern {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
using RewritePattern::RewritePattern;
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
@ -227,6 +289,43 @@ struct OpRewritePattern : public RewritePattern {
return failure();
}
};
} // namespace detail
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
SourceOp::getOperationName(), benefit, context) {}
};
/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of an operation interface instead
/// of a raw Operation.
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
benefit, context) {}
};
/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against instances of an operation that possess a
/// given trait.
template <template <typename> class TraitType>
class OpTraitRewritePattern : public RewritePattern {
public:
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
benefit, context) {}
};
//===----------------------------------------------------------------------===//
// PDLPatternModule

View File

@ -25,6 +25,10 @@ class FrozenRewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
/// A map of operation specific native patterns.
using OpSpecificNativePatternListT =
DenseMap<OperationName, std::vector<RewritePattern *>>;
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternSet();
FrozenRewritePatternSet(RewritePatternSet &&patterns);
@ -36,11 +40,16 @@ public:
operator=(FrozenRewritePatternSet &&patterns) = default;
~FrozenRewritePatternSet();
/// Return the native patterns held by this list.
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
return impl->nativeOpSpecificPatternMap;
}
/// Return the "match any" native patterns held by this list.
iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
getNativePatterns() const {
const NativePatternListT &nativePatterns = impl->nativePatterns;
return llvm::make_pointee_range(nativePatterns);
getMatchAnyOpNativePatterns() const {
const NativePatternListT &nativeList = impl->nativeAnyOpPatterns;
return llvm::make_pointee_range(nativeList);
}
/// Return the compiled PDL bytecode held by this list. Returns null if
@ -52,8 +61,17 @@ public:
private:
/// The internal implementation of the frozen pattern list.
struct Impl {
/// The set of native C++ rewrite patterns.
NativePatternListT nativePatterns;
/// The set of native C++ rewrite patterns that are matched to specific
/// operation kinds.
OpSpecificNativePatternListT nativeOpSpecificPatternMap;
/// The full op-specific native rewrite list. This allows for the map above
/// to contain duplicate patterns, e.g. for interfaces and traits.
NativePatternListT nativeOpSpecificPatternList;
/// The set of native C++ rewrite patterns that are matched to "any"
/// operation.
NativePatternListT nativeAnyOpPatterns;
/// The bytecode containing the compiled PDL patterns.
std::unique_ptr<detail::PDLByteCode> pdlByteCode;

View File

@ -183,6 +183,9 @@ public:
return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
}
/// Returns true if the interface map contains an interface for the given id.
bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
private:
/// Compare two TypeID instances by comparing the underlying pointer.
static bool compare(TypeID lhs, TypeID rhs) {

View File

@ -351,20 +351,12 @@ protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
using RewritePattern::RewritePattern;
/// Construct a conversion pattern that matches an operation with the given
/// root name. This constructor allows for providing a type converter to use
/// within the pattern.
ConversionPattern(StringRef rootName, PatternBenefit benefit,
TypeConverter &typeConverter, MLIRContext *ctx)
: RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
/// Construct a conversion pattern that matches any operation type. This
/// constructor allows for providing a type converter to use within the
/// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
MatchAnyOpTypeTag tag)
: RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
/// Construct a conversion pattern with the given converter, and forward the
/// remaining arguments to RewritePattern.
template <typename... Args>
ConversionPattern(TypeConverter &typeConverter, Args &&... args)
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
protected:
/// An optional type converter for use by this pattern.
@ -374,17 +366,13 @@ private:
using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
namespace detail {
/// OpOrInterfaceConversionPatternBase is a wrapper around ConversionPattern
/// that allows for matching and rewriting against an instance of a derived
/// operation class or an Interface as opposed to a raw Operation.
template <typename SourceOp>
struct OpConversionPattern : public ConversionPattern {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
context) {}
struct OpOrInterfaceConversionPatternBase : public ConversionPattern {
using ConversionPattern::ConversionPattern;
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
@ -419,6 +407,39 @@ struct OpConversionPattern : public ConversionPattern {
private:
using ConversionPattern::matchAndRewrite;
};
} // namespace detail
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
struct OpConversionPattern
: public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceConversionPatternBase<SourceOp>(
SourceOp::getOperationName(), benefit, context) {}
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: detail::OpOrInterfaceConversionPatternBase<SourceOp>(
typeConverter, SourceOp::getOperationName(), benefit, context) {}
};
/// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
/// allows for matching and rewriting against an instance of an OpInterface
/// class as opposed to a raw Operation.
template <typename SourceOp>
struct OpInterfaceConversionPattern
: public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceConversionPatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
benefit, context) {}
OpInterfaceConversionPattern(TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceConversionPatternBase<SourceOp>(
typeConverter, Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
};
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionLike op with the given type converter. This only supports

View File

@ -101,9 +101,9 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
}
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp op, PatternRewriter &rewriter) const {
// Only LinalgOp for which there is no specialized pattern go through this.
if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
return failure();
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
@ -199,8 +199,8 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
patterns.add<
CopyOpToLibraryCallRewrite,
CopyTransposeRewrite,
IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
patterns.add<LinalgOpToLibraryCallRewrite>();
IndexedGenericOpToLibraryCallRewrite,
LinalgOpToLibraryCallRewrite>(patterns.getContext());
// clang-format on
}

View File

@ -450,7 +450,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter,
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, typeConverter, context) {}
: ConversionPattern(typeConverter, rootOpName, benefit, context) {}
//===----------------------------------------------------------------------===//
// StructBuilder implementation

View File

@ -2366,16 +2366,12 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
//===----------------------------------------------------------------------===//
namespace {
struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
for (Value v : linalgOp.getShapedOperands()) {
for (Value v : op.getShapedOperands()) {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
@ -2383,7 +2379,7 @@ struct EraseDeadLinalgOp : public RewritePattern {
if (!mt)
continue;
if (llvm::is_contained(mt.getShape(), 0)) {
rewriter.eraseOp(linalgOp);
rewriter.eraseOp(op);
return success();
}
}
@ -2391,19 +2387,14 @@ struct EraseDeadLinalgOp : public RewritePattern {
}
};
struct FoldTensorCastOp : public RewritePattern {
FoldTensorCastOp(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
llvm::any_of(op.getShapedOperands(), [&](Value v) {
if (v.isa<BlockArgument>())
return false;
auto castOp = v.getDefiningOp<tensor::CastOp>();
@ -2417,23 +2408,23 @@ struct FoldTensorCastOp : public RewritePattern {
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (Value v : linalgOp.getInputs()) {
for (Value v : op.getInputs()) {
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
}
// Init tensors may fold, in which case the resultType must also change.
for (Value v : linalgOp.getOutputs()) {
for (Value v : op.getOutputs()) {
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
newResultTypes.push_back(newOperands.back().getType());
}
auto extraOperands = linalgOp.getAssumedNonShapedOperands();
auto extraOperands = op.getAssumedNonShapedOperands();
newOperands.append(extraOperands.begin(), extraOperands.end());
// Clone op.
Operation *newOp =
linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
@ -2500,17 +2491,15 @@ struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
namespace {
// Deduplicate redundant args of a linalg op.
// An arg is redundant if it has the same Value and indexing map as another.
struct DeduplicateInputs : public RewritePattern {
DeduplicateInputs(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
// This pattern reduces the number of arguments of an op, which breaks
// the invariants of semantically charged named ops.
if (!isa<GenericOp, IndexedGenericOp>(op))
return failure();
auto linalgOp = cast<LinalgOp>(op);
// Associate each input to an equivalent "canonical" input that has the same
// Value and indexing map.
@ -2524,9 +2513,9 @@ struct DeduplicateInputs : public RewritePattern {
// having a simple "inputIndex -> canonicalInputIndex" integer mapping is
// convenient.
SmallVector<int, 6> canonicalInputIndices;
for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
Value input = linalgOp.getInput(i);
AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
for (int i = 0, e = op.getNumInputs(); i != e; i++) {
Value input = op.getInput(i);
AffineMap indexingMap = op.getInputIndexingMap(i);
// STL-like maps have a convenient behavior for our use case here. In the
// case of duplicate keys, the insertion is rejected, and the returned
// iterator gives access to the value already in the map.
@ -2535,20 +2524,20 @@ struct DeduplicateInputs : public RewritePattern {
}
// If there are no duplicate args, then bail out.
if (canonicalInput.size() == linalgOp.getNumInputs())
if (canonicalInput.size() == op.getNumInputs())
return failure();
// The operands for the newly canonicalized op.
SmallVector<Value, 6> newOperands;
for (auto v : llvm::enumerate(linalgOp.getInputs()))
for (auto v : llvm::enumerate(op.getInputs()))
if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
newOperands.push_back(v.value());
llvm::append_range(newOperands, linalgOp.getOutputs());
llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());
llvm::append_range(newOperands, op.getOutputs());
llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
// Clone the old op with new operands.
Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
op->getResultTypes(), newOperands);
Operation *newOp =
op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
auto newLinalgOp = cast<LinalgOp>(newOp);
// Repair the indexing maps by filtering out the ones that have been
@ -2573,7 +2562,7 @@ struct DeduplicateInputs : public RewritePattern {
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp->getRegion(0).front();
for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
for (int i = 0, e = op.getNumInputs(); i < e; i++) {
// Iterate in reverse, so that we erase later args first, preventing the
// argument list from shifting unexpectedly and invalidating all our
// indices.
@ -2597,13 +2586,12 @@ struct DeduplicateInputs : public RewritePattern {
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
struct RemoveIdentityLinalgOps : public RewritePattern {
RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
if (auto copyOp = dyn_cast<CopyOp>(op)) {
if (auto copyOp = dyn_cast<CopyOp>(*op)) {
assert(copyOp.hasBufferSemantics());
if (copyOp.input() == copyOp.output() &&
copyOp.inputPermutation() == copyOp.outputPermutation()) {
@ -2614,11 +2602,10 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
if (!isa<GenericOp, IndexedGenericOp>(op))
return failure();
LinalgOp genericOp = cast<LinalgOp>(op);
if (!genericOp.hasTensorSemantics())
if (!op.hasTensorSemantics())
return failure();
// Check all indexing maps are identity.
if (llvm::any_of(genericOp.getIndexingMaps(),
if (llvm::any_of(op.getIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
return failure();
@ -2633,7 +2620,7 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
unsigned numIndexArgs = op.getNumPayloadInductionVariables();
SmallVector<Value, 4> returnedArgs;
for (Value yieldVal : yieldOp.values()) {
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
@ -2644,9 +2631,9 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
return failure();
returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
}
if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
if (returnedArgs.size() != op.getOperation()->getNumResults())
return failure();
rewriter.replaceOp(genericOp, returnedArgs);
rewriter.replaceOp(op, returnedArgs);
return success();
}
};
@ -2656,8 +2643,7 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
MLIRContext *context) { \
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps>(); \
results.add<ReplaceDimOfLinalgOpResult>(context); \
RemoveIdentityLinalgOps, ReplaceDimOfLinalgOpResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \

View File

@ -175,17 +175,15 @@ public:
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
BufferizeAnyLinalgOp(TypeConverter &typeConverter)
: ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
// GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
if (!op->hasAttr("operand_segment_sizes"))
return failure();
// We abuse the GenericOpAdaptor here.
@ -193,32 +191,30 @@ public:
// linalg::LinalgOp interface ops.
linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
Location loc = linalgOp.getLoc();
Location loc = op.getLoc();
SmallVector<Value, 2> newOutputBuffers;
if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(),
if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
newOutputBuffers, rewriter))) {
linalgOp.emitOpError()
<< "Failed to allocate buffers for tensor results.";
return failure();
return op.emitOpError()
<< "Failed to allocate buffers for tensor results.";
}
// Delegate to the linalg generic pattern.
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
finalizeBufferAllocationForGenericOp<GenericOp>(
rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
return success();
}
// Delegate to the linalg indexed generic pattern.
if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(*op)) {
finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
return success();
}
finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
newOutputBuffers);
finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
return success();
}
};
@ -338,10 +334,10 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<BufferizeAnyLinalgOp>(typeConverter);
// TODO: Drop this once tensor constants work in standard.
// clang-format off
patterns.add<
BufferizeAnyLinalgOp,
BufferizeFillOp,
BufferizeInitTensorOp,
SubTensorOpConverter,

View File

@ -83,7 +83,7 @@ public:
struct FunctionNonEntryBlockConversion : public ConversionPattern {
FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
: ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,

View File

@ -75,8 +75,8 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
namespace {
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors()
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
if (!isElementwiseMappableOpOnRankedTensors(op))
@ -117,7 +117,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
void mlir::populateElementwiseToLinalgConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>();
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
patterns.getContext());
}
namespace {

View File

@ -104,7 +104,7 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()),
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
marker(std::move(marker)) {}
LogicalResult matchAndRewrite(Operation *rootOp,

View File

@ -520,8 +520,9 @@ namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
LinalgRewritePattern(ArrayRef<unsigned> interchangeVector)
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
LinalgRewritePattern(MLIRContext *context,
ArrayRef<unsigned> interchangeVector)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult matchAndRewrite(Operation *op,
@ -546,7 +547,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(interchangeVector);
patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);

View File

@ -234,13 +234,13 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), filter(filter),
: RewritePattern(opName, benefit, context), filter(filter),
options(options) {}
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
LinalgTilingOptions options, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter),
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
options(options) {}
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
@ -306,7 +306,7 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context),
: RewritePattern(opName, benefit, context, {}),
dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
fusionOptions(fusionOptions), filter(filter),
fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
@ -401,7 +401,7 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), filter(filter),
: RewritePattern(opName, benefit, context, {}), filter(filter),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
@ -427,7 +427,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), filter(filter),
: RewritePattern(opName, benefit, context, {}), filter(filter),
options(options) {}
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
@ -453,13 +453,14 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), filter(filter) {}
: RewritePattern(opName, benefit, context, {}), filter(filter) {}
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {

View File

@ -46,25 +46,21 @@ namespace {
/// Only needed to support partial conversion of functions where this pattern
/// ensures that the branch operation arguments matches up with the succesor
/// block arguments.
class BranchOpInterfaceTypeConversion : public ConversionPattern {
class BranchOpInterfaceTypeConversion
: public OpInterfaceConversionPattern<BranchOpInterface> {
public:
BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
MLIRContext *ctx)
: ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
using OpInterfaceConversionPattern<
BranchOpInterface>::OpInterfaceConversionPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto branchOp = dyn_cast<BranchOpInterface>(op);
if (!branchOp)
return failure();
// For a branch operation, only some operands go to the target blocks, so
// only rewrite those.
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
succIdx < succEnd; ++succIdx) {
auto successorOperands = branchOp.getSuccessorOperands(succIdx);
auto successorOperands = op.getSuccessorOperands(succIdx);
if (!successorOperands)
continue;
for (int idx = successorOperands->getBeginOperandIndex(),

View File

@ -29,23 +29,49 @@ unsigned short PatternBenefit::getBenefit() const {
// Pattern
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// OperationName Root Constructors
Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context, ArrayRef<StringRef> generatedNames)
: Pattern(OperationName(rootName, context).getAsOpaquePointer(),
RootKind::OperationName, generatedNames, benefit, context) {}
//===----------------------------------------------------------------------===//
// MatchAnyOpTypeTag Root Constructors
Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
MLIRContext *context, ArrayRef<StringRef> generatedNames)
: Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
//===----------------------------------------------------------------------===//
// MatchInterfaceOpTypeTag Root Constructors
Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
PatternBenefit benefit, MLIRContext *context,
ArrayRef<StringRef> generatedNames)
: Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
generatedNames, benefit, context) {}
//===----------------------------------------------------------------------===//
// MatchTraitOpTypeTag Root Constructors
Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
PatternBenefit benefit, MLIRContext *context,
ArrayRef<StringRef> generatedNames)
: Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
benefit, context) {}
//===----------------------------------------------------------------------===//
// General Constructors
Pattern::Pattern(const void *rootValue, RootKind rootKind,
ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context)
: rootKind(OperationName(rootName, context)), benefit(benefit) {}
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
: benefit(benefit) {}
Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
: Pattern(rootName, benefit, context) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {
: rootValue(rootValue), rootKind(rootKind), benefit(benefit),
contextAndHasBoundedRecursion(context, false) {
if (generatedNames.empty())
return;
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {

View File

@ -45,10 +45,10 @@ PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
// Check to see if this is pattern matches a specific operation type.
if (Optional<StringRef> rootKind = matchOp.rootKind())
return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
ctx);
return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
MatchAnyOpTypeTag());
return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
generatedOps);
return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
generatedOps);
}
//===----------------------------------------------------------------------===//

View File

@ -55,7 +55,43 @@ FrozenRewritePatternSet::FrozenRewritePatternSet()
FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
: impl(std::make_shared<Impl>()) {
impl->nativePatterns = std::move(patterns.getNativePatterns());
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
std::vector<AbstractOperation *> abstractOps;
auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
function_ref<bool(AbstractOperation *)> callbackFn) {
if (abstractOps.empty())
abstractOps = pattern->getContext()->getRegisteredOperations();
for (AbstractOperation *absOp : abstractOps) {
if (callbackFn(absOp)) {
OperationName opName(absOp);
impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
}
}
impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
};
for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
if (Optional<OperationName> rootName = pat->getRootKind()) {
impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
impl->nativeOpSpecificPatternList.push_back(std::move(pat));
continue;
}
if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
addToOpsWhen(pat, [&](AbstractOperation *absOp) {
return absOp->hasInterface(*interfaceID);
});
continue;
}
if (Optional<TypeID> traitID = pat->getRootTraitID()) {
addToOpsWhen(pat, [&](AbstractOperation *absOp) {
return absOp->hasTrait(*traitID);
});
continue;
}
impl->nativeAnyOpPatterns.push_back(std::move(pat));
}
// Generate the bytecode for the PDL patterns if any were provided.
PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();

View File

@ -15,6 +15,8 @@
#include "ByteCode.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "pattern-match"
using namespace mlir;
using namespace mlir::detail;
@ -28,7 +30,14 @@ PatternApplicator::PatternApplicator(
}
PatternApplicator::~PatternApplicator() {}
#define DEBUG_TYPE "pattern-match"
/// Log a message for a pattern that is impossible to match.
static void logImpossibleToMatch(const Pattern &pattern) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
}
void PatternApplicator::applyCostModel(CostModel model) {
// Apply the cost model to the bytecode patterns first, and then the native
@ -38,23 +47,24 @@ void PatternApplicator::applyCostModel(CostModel model) {
mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
}
// Separate patterns by root kind to simplify lookup later on.
// Copy over the patterns so that we can sort by benefit based on the cost
// model. Patterns that are already impossible to match are ignored.
patterns.clear();
anyOpPatterns.clear();
for (const auto &pat : frozenPatternList.getNativePatterns()) {
// If the pattern is always impossible to match, just ignore it.
if (pat.getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs()
<< "Ignoring pattern '" << pat.getRootKind()
<< "' because it is impossible to match (by pattern benefit)\n";
});
continue;
for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
for (const RewritePattern *pattern : it.second) {
if (pattern->getBenefit().isImpossibleToMatch())
logImpossibleToMatch(*pattern);
else
patterns[it.first].push_back(pattern);
}
if (Optional<OperationName> opName = pat.getRootKind())
patterns[*opName].push_back(&pat);
}
anyOpPatterns.clear();
for (const RewritePattern &pattern :
frozenPatternList.getMatchAnyOpNativePatterns()) {
if (pattern.getBenefit().isImpossibleToMatch())
logImpossibleToMatch(pattern);
else
anyOpPatterns.push_back(&pat);
anyOpPatterns.push_back(&pattern);
}
// Sort the patterns using the provided cost model.
@ -66,11 +76,7 @@ void PatternApplicator::applyCostModel(CostModel model) {
// Special case for one pattern in the list, which is the most common case.
if (list.size() == 1) {
if (model(*list.front()).isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
logImpossibleToMatch(*list.front());
list.clear();
}
return;
@ -84,14 +90,8 @@ void PatternApplicator::applyCostModel(CostModel model) {
// Sort patterns with highest benefit first, and remove those that are
// impossible to match.
std::stable_sort(list.begin(), list.end(), cmp);
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
<< "' because it is impossible to match or cannot lead to "
"legal IR (by cost model)\n";
});
list.pop_back();
}
while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
logImpossibleToMatch(*list.pop_back_val());
};
for (auto &it : patterns)
processPatternList(it.second);
@ -100,7 +100,10 @@ void PatternApplicator::applyCostModel(CostModel model) {
void PatternApplicator::walkAllPatterns(
function_ref<void(const Pattern &)> walk) {
for (const Pattern &it : frozenPatternList.getNativePatterns())
for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
for (const auto &pattern : it.second)
walk(*pattern);
for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
walk(it);
if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
for (const Pattern &it : bytecode->getPatterns())

View File

@ -2582,7 +2582,7 @@ namespace {
struct FunctionLikeSignatureConversion : public ConversionPattern {
FunctionLikeSignatureConversion(StringRef functionLikeOpName,
MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
: ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
/// Hook to implement combined matching and rewriting for FunctionLike ops.
LogicalResult

View File

@ -149,8 +149,8 @@ void ConvertToTargetEnv::runOnFunction() {
}
ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
: RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
{"spv.AtomicCompareExchangeWeak"}, 1, context) {}
: RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
context, {"spv.AtomicCompareExchangeWeak"}) {}
LogicalResult
ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
@ -170,8 +170,8 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
}
ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
: RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
context) {}
: RewritePattern("test.convert_to_bit_reverse_op", 1, context,
{"spv.BitReverse"}) {}
LogicalResult
ConvertToBitReverse::matchAndRewrite(Operation *op,
@ -185,8 +185,8 @@ ConvertToBitReverse::matchAndRewrite(Operation *op,
ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
MLIRContext *context)
: RewritePattern("test.convert_to_group_non_uniform_ballot_op",
{"spv.GroupNonUniformBallot"}, 1, context) {}
: RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
{"spv.GroupNonUniformBallot"}) {}
LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
@ -198,7 +198,7 @@ LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
}
ConvertToModule::ConvertToModule(MLIRContext *context)
: RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
: RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
LogicalResult
ConvertToModule::matchAndRewrite(Operation *op,
@ -210,8 +210,8 @@ ConvertToModule::matchAndRewrite(Operation *op,
}
ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
: RewritePattern("test.convert_to_subgroup_ballot_op",
{"spv.SubgroupBallotKHR"}, 1, context) {}
: RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
{"spv.SubgroupBallotKHR"}) {}
LogicalResult
ConvertToSubgroupBallot::matchAndRewrite(Operation *op,

View File

@ -325,7 +325,7 @@ struct TestUndoBlockErase : public ConversionPattern {
/// This patterns erases a region operation that has had a type conversion.
struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
: ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@ -726,7 +726,8 @@ struct TestRemappedValue
namespace {
/// This pattern matches and removes any operation in the test dialect.
struct RemoveTestDialectOps : public RewritePattern {
RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
RemoveTestDialectOps(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@ -741,7 +742,7 @@ struct TestUnknownRootOpDriver
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
void runOnFunction() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<RemoveTestDialectOps>();
patterns.add<RemoveTestDialectOps>(&getContext());
mlir::ConversionTarget target(getContext());
target.addIllegalDialect<TestDialect>();

View File

@ -183,8 +183,8 @@ static void applyPatterns(FuncOp funcOp) {
// Linalg to vector contraction patterns.
//===--------------------------------------------------------------------===//
patterns.add<LinalgVectorizationPattern>(
LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
@ -258,8 +258,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
LinalgTransformationFilter(Identifier::get("VEC", ctx))));
patternsVector.back().add<LinalgVectorizationPattern>(
LinalgTransformationFilter().addFilter(
[](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
ctx, LinalgTransformationFilter().addFilter(
[](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}
//===----------------------------------------------------------------------===//
@ -496,6 +496,7 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<LinalgVectorizationPattern>(
funcOp.getContext(),
LinalgTransformationFilter()
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());

View File

@ -2075,8 +2075,8 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>();
results.add<FoldTensorCastOp>();
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{

View File

@ -521,8 +521,8 @@ const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>();
results.add<FoldTensorCastOp>();
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{

View File

@ -626,8 +626,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
{0}(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("{1}", {{)",
rewriteName, rootName);
: ::mlir::RewritePattern("{1}", {2}, context, {{)",
rewriteName, rootName, pattern.getBenefit());
// Sort result operators by name.
llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
resultOps.end());
@ -637,7 +637,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
os << '"' << op->getOperationName() << '"';
});
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
os << "}) {}\n";
// Emit matchAndRewrite() function.
{

View File

@ -38,8 +38,9 @@ TEST(PatternBenefitTest, BenefitOrder) {
};
struct Pattern2 : public RewritePattern {
Pattern2(bool *called)
: RewritePattern(/*benefit*/ 2, MatchAnyOpTypeTag{}), called(called) {}
Pattern2(MLIRContext *context, bool *called)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
called(called) {}
mlir::LogicalResult
matchAndRewrite(Operation * /*op*/,
@ -58,7 +59,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
bool called2 = false;
patterns.add<Pattern1>(&context, &called1);
patterns.add<Pattern2>(&called2);
patterns.add<Pattern2>(&context, &called2);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
PatternApplicator pa(frozenPatterns);