From 6090643877fde4d7f00c25e0d78e1f7eec4f3cdb Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 17 Oct 2019 15:18:31 -0700 Subject: [PATCH] Introduce a wrapper around ConversionPattern that operates on the derived class Analogous to OpRewritePattern, this makes writing conversion patterns more convenient. PiperOrigin-RevId: 275349854 --- .../mlir/Transforms/DialectConversion.h | 76 +++++++++++++++++++ mlir/lib/Transforms/DialectConversion.cpp | 10 +-- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 2c22911f9e50..fb44764f127a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -223,6 +223,82 @@ private: using RewritePattern::rewrite; }; +/// OpConversionPattern is a wrapper around ConversionPattern that allows for +/// matching and rewriting against an instance of a derived operation class as +/// opposed to a raw Operation. +template +struct OpConversionPattern : public ConversionPattern { + OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} + + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(llvm::cast(op), operands, rewriter); + } + void rewrite(Operation *op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(llvm::cast(op), properOperands, destinations, operands, + rewriter); + } + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(llvm::cast(op), properOperands, + destinations, operands, rewriter); + } + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(llvm::cast(op), operands, rewriter); + } + + // TODO(b/142763075): Use OperandAdaptor when it supports access to unnamed + // operands. + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + + virtual void rewrite(SourceOp op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("unimplemented rewrite for terminators"); + } + + virtual PatternMatchResult + matchAndRewrite(SourceOp op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + ConversionPatternRewriter &rewriter) const { + if (!match(op)) + return matchFailure(); + rewrite(op, properOperands, destinations, operands, rewriter); + return matchSuccess(); + } + + virtual PatternMatchResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + if (!match(op)) + return matchFailure(); + rewrite(op, operands, rewriter); + return matchSuccess(); + } + +private: + using ConversionPattern::matchAndRewrite; +}; + /// Add a pattern to the given pattern list to convert the signature of a FuncOp /// with the given type converter. void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns, diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 433986c65576..e82100c0d392 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1393,16 +1393,14 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, /// Create a default conversion pattern that rewrites the type signature of a /// FuncOp. namespace { -struct FuncOpSignatureConversion : public ConversionPattern { +struct FuncOpSignatureConversion : public OpConversionPattern { FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : ConversionPattern(FuncOp::getOperationName(), 1, ctx), - converter(converter) {} + : OpConversionPattern(ctx), converter(converter) {} /// Hook for derived classes to implement combined matching and rewriting. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); FunctionType type = funcOp.getType(); // Convert the original function arguments. @@ -1425,7 +1423,7 @@ struct FuncOpSignatureConversion : public ConversionPattern { // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); - rewriter.eraseOp(op); + rewriter.eraseOp(funcOp); return matchSuccess(); }