Introduce a wrapper around ConversionPattern that operates on the derived class

Analogous to OpRewritePattern, this makes writing conversion patterns more convenient.

PiperOrigin-RevId: 275349854
This commit is contained in:
Geoffrey Martin-Noble 2019-10-17 15:18:31 -07:00 committed by A. Unique TensorFlower
parent b65c8bb5d6
commit 6090643877
2 changed files with 80 additions and 6 deletions

View File

@ -223,6 +223,82 @@ private:
using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp>
struct OpConversionPattern : public ConversionPattern {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), properOperands, destinations, operands,
rewriter);
}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(llvm::cast<SourceOp>(op), properOperands,
destinations, operands, rewriter);
}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(llvm::cast<SourceOp>(op), operands, rewriter);
}
// TODO(b/142763075): Use OperandAdaptor when it supports access to unnamed
// operands.
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}
virtual PatternMatchResult
matchAndRewrite(SourceOp op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, properOperands, destinations, operands, rewriter);
return matchSuccess();
}
virtual PatternMatchResult
matchAndRewrite(SourceOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, operands, rewriter);
return matchSuccess();
}
private:
using ConversionPattern::matchAndRewrite;
};
/// Add a pattern to the given pattern list to convert the signature of a FuncOp
/// with the given type converter.
void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,

View File

@ -1393,16 +1393,14 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
/// Create a default conversion pattern that rewrites the type signature of a
/// FuncOp.
namespace {
struct FuncOpSignatureConversion : public ConversionPattern {
struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(FuncOp::getOperationName(), 1, ctx),
converter(converter) {}
: OpConversionPattern(ctx), converter(converter) {}
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Convert the original function arguments.
@ -1425,7 +1423,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
rewriter.eraseOp(op);
rewriter.eraseOp(funcOp);
return matchSuccess();
}