Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.

--

PiperOrigin-RevId: 250003405
This commit is contained in:
River Riddle 2019-05-25 17:22:27 -07:00 committed by Mehdi Amini
parent 2f50b6c401
commit 9e21ab8f52
9 changed files with 215 additions and 250 deletions

View File

@ -248,12 +248,12 @@ namespace {
/// mlir::StoreOp requires finding the proper indexing in the supporting MemRef.
/// This is most easily achieved by calling emitAndReturnFullyComposedView to
/// fold away all the SliceOp.
template <typename LoadOrStoreOpTy> struct Rewriter : public RewritePattern {
explicit Rewriter(MLIRContext *context)
: RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {}
template <typename LoadOrStoreOpTy>
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;
/// Performs the rewrite.
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
PatternRewriter &rewriter) const override;
};
@ -270,9 +270,8 @@ struct LowerLinalgLoadStorePass
template <>
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
PatternRewriter &rewriter) const {
auto load = cast<linalg::LoadOp>(op);
SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(load.getView()->getDefiningOp());
@ -280,15 +279,14 @@ Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
ScopedContext scope(builder, load.getLoc());
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(load, view);
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
return matchSuccess();
}
template <>
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
PatternRewriter &rewriter) const {
auto store = cast<linalg::StoreOp>(op);
SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(store.getView()->getDefiningOp());
@ -297,7 +295,7 @@ Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
auto *valueToStore = store.getValueToStore();
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(store, view);
rewriter.replaceOpWithNewOp<mlir::StoreOp>(op, valueToStore, memRef,
rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
operands);
return matchSuccess();
}

View File

@ -33,25 +33,21 @@ namespace toy {
namespace {
/// Fold transpose(transpose(x) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
context) {}
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// We can directly cast the current operation as this will only get invoked
// on TransposeOp.
TransposeOp transpose = llvm::cast<TransposeOp>(op);
// Look through the input of the current transpose.
mlir::Value *transposeInput = transpose.getOperand();
mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
@ -65,15 +61,12 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
struct SimplifyReshapeConstant : public mlir::RewritePattern {
SimplifyReshapeConstant(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
@ -81,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
if (!constantOp)
return matchFailure();
auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
@ -90,7 +83,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
@ -102,7 +95,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
@ -112,17 +105,15 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
};
/// Fold reshape(reshape(x)) -> reshape(x)
struct SimplifyReshapeReshape : public mlir::RewritePattern {
SimplifyReshapeReshape(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
mlir::Value *reshapeInput = reshape.getOperand();
mlir::Value *reshapeInput = op.getOperand();
// If the input is defined by another reshape, bingo!
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
@ -134,18 +125,15 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
};
/// Fold reshape(x)) -> x, when input type matches output type
struct SimplifyNullReshape : public mlir::RewritePattern {
SimplifyNullReshape(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
if (op.getOperand()->getType() != op.getType())
return matchFailure();
rewriter.replaceOp(reshape, {reshape.getOperand()});
rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};

View File

@ -22,7 +22,7 @@
#include "toy/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@ -32,30 +32,26 @@ namespace toy {
namespace {
/// Fold transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
/// Fold transpose(transpose(x) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
context) {}
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// We can directly cast the current operation as this will only get invoked
// on TransposeOp.
TransposeOp transpose = llvm::cast<TransposeOp>(op);
// look through the input to the current transpose
mlir::Value *transposeInput = transpose.getOperand();
mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
// If the input is defined by another Transpose, bingo!
// Look through the input of the current transpose.
mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
if (!transposeInputOp)
return matchFailure();
@ -66,25 +62,21 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
struct SimplifyReshapeConstant : public mlir::RewritePattern {
SimplifyReshapeConstant(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// look through the input to the current reshape
mlir::Value *reshapeInput = reshape.getOperand();
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
// If the input is defined by another reshape, bingo!
ConstantOp constantOp =
mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
// If the input is defined by another constant, bingo!
if (!constantOp)
return matchFailure();
auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
@ -93,9 +85,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
auto newConstant = rewriter.create<ConstantOp>(
constantOp.getLoc(), reshapeType.getShape(), newAttr);
rewriter.replaceOp(op, {newConstant});
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
// Broadcast
@ -106,9 +97,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
auto newConstant = rewriter.create<ConstantOp>(
constantOp.getLoc(), reshapeType.getShape(), newAttr);
rewriter.replaceOp(op, {newConstant});
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
}
@ -117,43 +107,35 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
};
/// Fold reshape(reshape(x)) -> reshape(x)
struct SimplifyReshapeReshape : public mlir::RewritePattern {
SimplifyReshapeReshape(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// look through the input to the current reshape
mlir::Value *reshapeInput = reshape.getOperand();
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
// Look through the input of the current reshape.
mlir::Value *reshapeInput = op.getOperand();
// If the input is defined by another reshape, bingo!
ReshapeOp reshapeInputOp =
mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
if (!reshapeInputOp)
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
// Use the rewriter to perform the replacement
rewriter.replaceOp(op, {reshapeInputOp});
rewriter.replaceOp(op, {reshapeInput});
return matchSuccess();
}
};
/// Fold reshape(x)) -> x, when input type matches output type
struct SimplifyNullReshape : public mlir::RewritePattern {
SimplifyNullReshape(mlir::MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
if (op.getOperand()->getType() != op.getType())
return matchFailure();
rewriter.replaceOp(reshape, {reshape.getOperand()});
rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};
@ -176,17 +158,14 @@ void ReshapeOp::getCanonicalizationPatterns(
namespace {
/// Fold type.cast(x) -> x, when input type matches output type
struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
SimplifyIdentityTypeCast(mlir::MLIRContext *context)
: RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
context) {}
struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
matchAndRewrite(TypeCastOp typeCast,
mlir::PatternRewriter &rewriter) const override {
TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
auto resTy = typeCast.getResult()->getType();
auto *candidateOp = op;
auto resTy = typeCast.getType();
auto *candidateOp = typeCast.getOperation();
while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
if (resTy == candidateOp->getOperand(0)->getType()) {
rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});

View File

@ -205,6 +205,53 @@ protected:
llvm::SmallVector<OperationName, 2> generatedOps;
};
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation class as
/// opposed to a raw Operation.
template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
}
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(llvm::cast<SourceOp>(op), rewriter);
}
PatternMatchResult match(Operation *op) const final {
return match(llvm::cast<SourceOp>(op));
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const {
rewrite(op, rewriter);
}
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual PatternMatchResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual PatternMatchResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
if (auto matchResult = match(op)) {
rewrite(op, std::move(*matchResult), rewriter);
return matchSuccess();
}
return matchFailure();
}
};
//===----------------------------------------------------------------------===//
// PatternRewriter class
//===----------------------------------------------------------------------===//

View File

@ -654,24 +654,21 @@ void mlir::canonicalizeMapAndOperands(
namespace {
/// Simplify AffineApply operations.
///
struct SimplifyAffineApply : public RewritePattern {
SimplifyAffineApply(MLIRContext *context)
: RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(AffineApplyOp apply,
PatternRewriter &rewriter) const override {
auto apply = cast<AffineApplyOp>(op);
auto map = apply.getAffineMap();
AffineMap oldMap = map;
SmallVector<Value *, 8> resultOperands(apply.getOperands());
composeAffineMapAndOperands(&map, &resultOperands);
if (map != oldMap) {
rewriter.replaceOpWithNewOp<AffineApplyOp>(op, map, resultOperands);
return matchSuccess();
}
if (map == oldMap)
return matchFailure();
return matchFailure();
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
return matchSuccess();
}
};
} // end anonymous namespace.
@ -1002,14 +999,11 @@ void AffineForOp::print(OpAsmPrinter *p) {
namespace {
/// This is a pattern to fold constant loop bounds.
struct AffineForLoopBoundFolder : public RewritePattern {
/// The rootOpName is the name of the root operation to match against.
AffineForLoopBoundFolder(MLIRContext *context)
: RewritePattern(AffineForOp::getOperationName(), 1, context) {}
struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
auto forOp = cast<AffineForOp>(op);
auto foldLowerOrUpperBound = [&forOp](bool lower) {
// Check to see if each of the operands is the result of a constant. If
// so, get the value. If not, ignore it.
@ -1056,7 +1050,7 @@ struct AffineForLoopBoundFolder : public RewritePattern {
// If any of the bounds were folded we return success.
if (!folded)
return matchFailure();
rewriter.updatedRootInPlace(op);
rewriter.updatedRootInPlace(forOp);
return matchSuccess();
}
};

View File

@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input,
namespace {
struct UniformDequantizePattern : public RewritePattern {
UniformDequantizePattern(MLIRContext *context)
: RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const {
auto dcastOp = cast<DequantizeCastOp>(op);
Type inputType = dcastOp.arg()->getType();
Type outputType = dcastOp.getResult()->getType();
Type inputType = op.arg()->getType();
Type outputType = op.getResult()->getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern {
return matchFailure();
}
Value *dequantizedValue =
emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
namespace {
struct UniformRealAddEwPattern : public RewritePattern {
UniformRealAddEwPattern(MLIRContext *context)
: RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const {
auto addOp = cast<RealAddEwOp>(op);
const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
addOp.clamp_min(), addOp.clamp_max());
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern {
}
};
struct UniformRealMulEwPattern : public RewritePattern {
UniformRealMulEwPattern(MLIRContext *context)
: RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const {
auto mulOp = cast<RealMulEwOp>(op);
const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
mulOp.clamp_min(), mulOp.clamp_max());
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}

View File

@ -38,26 +38,21 @@ namespace {
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
/// value of x if the casts invert each other.
class RemoveRedundantStorageCastsRewrite : public RewritePattern {
class RemoveRedundantStorageCastsRewrite
: public OpRewritePattern<StorageCastOp> {
public:
RemoveRedundantStorageCastsRewrite(MLIRContext *context)
: RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
using OpRewritePattern<StorageCastOp>::OpRewritePattern;
PatternMatchResult match(Operation *op) const override {
auto scastOp = cast<StorageCastOp>(op);
if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
return matchSuccess();
}
}
return matchFailure();
}
PatternMatchResult matchAndRewrite(StorageCastOp op,
PatternRewriter &rewriter) const override {
if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
return matchFailure();
auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
if (srcScastOp.arg()->getType() != op.getType())
return matchFailure();
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto scastOp = cast<StorageCastOp>(op);
auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
rewriter.replaceOp(op, srcScastOp.arg());
return matchSuccess();
}
};

View File

@ -36,40 +36,35 @@ public:
void runOnFunction() override;
};
class QuantizedConstRewrite : public RewritePattern {
public:
struct State : PatternState {
QuantizedType quantizedElementType;
Attribute value;
};
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
QuantizedConstRewrite(MLIRContext *context)
: RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override;
void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
PatternRewriter &rewriter) const override;
PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
State state;
PatternMatchResult
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const {
Attribute value;
// Is the operand a constant?
auto qbarrier = cast<QuantizeCastOp>(op);
if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
return matchFailure();
}
// Does the qbarrier convert to a quantized type. This will not be true
// if a quantized type has not yet been chosen or if the cast to an equivalent
// storage type is not supported.
Type qbarrierResultType = qbarrier.getResult()->getType();
state.quantizedElementType =
QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
if (!state.quantizedElementType) {
if (!quantizedElementType) {
return matchFailure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
// Is the operand type compatible with the expressed type of the quantized
// type? This will not be true if the qbarrier is superfluous (converts
// from and to a quantized type).
if (!state.quantizedElementType.isCompatibleExpressedType(
if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg()->getType())) {
return matchFailure();
}
// Is the constant value a type expressed in a way that we support?
if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
!state.value.isa<DenseElementsAttr>() &&
!state.value.isa<SparseElementsAttr>()) {
if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
!value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
return matchFailure();
}
return matchSuccess(llvm::make_unique<State>(std::move(state)));
}
void QuantizedConstRewrite::rewrite(Operation *op,
std::unique_ptr<PatternState> baseState,
PatternRewriter &rewriter) const {
auto state = static_cast<State *>(baseState.get());
Type newConstValueType;
Attribute newConstValue = quantizeAttr(
state->value, state->quantizedElementType, newConstValueType);
auto newConstValue =
quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
return;
return matchFailure();
}
auto *origConstOp = op->getOperand(0);
// When creating the new const op, use a fused location that combines the
// original const and the qbarrier that led to the quantization.
auto fusedLoc =
FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
rewriter.getContext());
auto fusedLoc = FusedLoc::get(
{qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
rewriter.replaceOpWithNewOp<StorageCastOp>(
{origConstOp}, op, *op->result_type_begin(), newConstOp);
rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
qbarrier.getType(), newConstOp);
return matchSuccess();
}
void ConvertConstPass::runOnFunction() {

View File

@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) {
namespace {
/// Fold constant dimensions into an alloc operation.
struct SimplifyAllocConst : public RewritePattern {
SimplifyAllocConst(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
auto alloc = cast<AllocOp>(op);
struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
for (auto *operand : alloc.getOperands())
if (matchPattern(operand, m_ConstantIndex()))
return matchSuccess();
return matchFailure();
}
if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return matchFailure();
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto allocOp = cast<AllocOp>(op);
auto memrefType = allocOp.getType();
auto memrefType = alloc.getType();
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern {
newShapeConstants.push_back(dimSize);
continue;
}
auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp();
auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern {
} else {
// Dynamic shape dimension not folded; copy operand from old memref.
newShapeConstants.push_back(-1);
newOperands.push_back(allocOp.getOperand(dynamicDimPos));
newOperands.push_back(alloc.getOperand(dynamicDimPos));
}
dynamicDimPos++;
}
@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern {
// Create and insert the alloc op for the new memref.
auto newAlloc =
rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands);
rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
// Insert a cast so we have the same type as the old alloc.
auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc,
allocOp.getType());
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
alloc.getType());
rewriter.replaceOp(op, {resultCast}, droppedOperands);
rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
return matchSuccess();
}
};
/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public RewritePattern {
SimplifyDeadAlloc(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
// Check if the alloc'ed value has any uses.
auto alloc = cast<AllocOp>(op);
if (!alloc.use_empty())
return matchFailure();
// If it doesn't, we can eliminate it.
op->erase();
alloc.erase();
return matchSuccess();
}
};
@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() {
//===----------------------------------------------------------------------===//
namespace {
/// Fold indirect calls that have a constant function as the callee operand.
struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
: RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
struct SimplifyIndirectCallWithKnownCallee
: public OpRewritePattern<CallIndirectOp> {
using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
auto indirectCall = cast<CallIndirectOp>(op);
// Check that the callee is a constant callee.
FunctionAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
// Replace with a direct call.
SmallVector<Type, 8> callResults(op->getResultTypes());
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
callOperands);
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
callResults, callOperands);
return matchSuccess();
}
};
@ -964,14 +956,11 @@ namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
auto condbr = cast<CondBranchOp>(op);
// Check that the condition is a constant.
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
return matchFailure();
@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
return matchSuccess();
}
};
@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
namespace {
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
/// by other Dealloc operations.
struct SimplifyDeadDealloc : public RewritePattern {
SimplifyDeadDealloc(MLIRContext *context)
: RewritePattern(DeallocOp::getOperationName(), 1, context) {}
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
auto dealloc = cast<DeallocOp>(op);
// Check that the memref operand's defining operation is an AllocOp.
Value *memref = dealloc.memref();
Operation *defOp = memref->getDefiningOp();
if (!isa_and_nonnull<AllocOp>(defOp))
if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
return matchFailure();
// Check that all of the uses of the AllocOp are other DeallocOps.
@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
return matchFailure();
// Erase the dealloc operation.
op->erase();
rewriter.replaceOp(dealloc, llvm::None);
return matchSuccess();
}
};