forked from OSchip/llvm-project
Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*.
-- PiperOrigin-RevId: 250003405
This commit is contained in:
parent
2f50b6c401
commit
9e21ab8f52
|
@ -248,12 +248,12 @@ namespace {
|
|||
/// mlir::StoreOp requires finding the proper indexing in the supporting MemRef.
|
||||
/// This is most easily achieved by calling emitAndReturnFullyComposedView to
|
||||
/// fold away all the SliceOp.
|
||||
template <typename LoadOrStoreOpTy> struct Rewriter : public RewritePattern {
|
||||
explicit Rewriter(MLIRContext *context)
|
||||
: RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {}
|
||||
template <typename LoadOrStoreOpTy>
|
||||
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
|
||||
using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;
|
||||
|
||||
/// Performs the rewrite.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
|
@ -270,9 +270,8 @@ struct LowerLinalgLoadStorePass
|
|||
|
||||
template <>
|
||||
PatternMatchResult
|
||||
Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
|
||||
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto load = cast<linalg::LoadOp>(op);
|
||||
SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
|
||||
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
|
||||
: cast<ViewOp>(load.getView()->getDefiningOp());
|
||||
|
@ -280,15 +279,14 @@ Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
|
|||
ScopedContext scope(builder, load.getLoc());
|
||||
auto *memRef = view.getSupportingMemRef();
|
||||
auto operands = emitAndReturnLoadStoreOperands(load, view);
|
||||
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
|
||||
rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
template <>
|
||||
PatternMatchResult
|
||||
Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
|
||||
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto store = cast<linalg::StoreOp>(op);
|
||||
SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
|
||||
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
|
||||
: cast<ViewOp>(store.getView()->getDefiningOp());
|
||||
|
@ -297,7 +295,7 @@ Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
|
|||
auto *valueToStore = store.getValueToStore();
|
||||
auto *memRef = view.getSupportingMemRef();
|
||||
auto operands = emitAndReturnLoadStoreOperands(store, view);
|
||||
rewriter.replaceOpWithNewOp<mlir::StoreOp>(op, valueToStore, memRef,
|
||||
rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
|
||||
operands);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
|
|
@ -33,25 +33,21 @@ namespace toy {
|
|||
namespace {
|
||||
|
||||
/// Fold transpose(transpose(x) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
/// We register this pattern to match every toy.transpose in the IR.
|
||||
/// The "benefit" is used by the framework to order the patterns and process
|
||||
/// them in order of profitability.
|
||||
SimplifyRedundantTranspose(mlir::MLIRContext *context)
|
||||
: RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// This method is attempting to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// We can directly cast the current operation as this will only get invoked
|
||||
// on TransposeOp.
|
||||
TransposeOp transpose = llvm::cast<TransposeOp>(op);
|
||||
// Look through the input of the current transpose.
|
||||
mlir::Value *transposeInput = transpose.getOperand();
|
||||
mlir::Value *transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
|
@ -65,15 +61,12 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
|
|||
};
|
||||
|
||||
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
|
||||
struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
||||
SimplifyReshapeConstant(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp reshape,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
// Look through the input of the current reshape.
|
||||
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
|
||||
reshape.getOperand()->getDefiningOp());
|
||||
|
@ -81,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
if (!constantOp)
|
||||
return matchFailure();
|
||||
|
||||
auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
|
||||
auto reshapeType = reshape.getType().cast<ToyArrayType>();
|
||||
if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
|
||||
// FIXME Check matching of element count!
|
||||
|
@ -90,7 +83,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
|
||||
|
@ -102,7 +95,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
|
||||
reshapeType.getElementType());
|
||||
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else {
|
||||
llvm_unreachable("Unsupported Constant format");
|
||||
|
@ -112,17 +105,15 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
};
|
||||
|
||||
/// Fold reshape(reshape(x)) -> reshape(x)
|
||||
struct SimplifyReshapeReshape : public mlir::RewritePattern {
|
||||
SimplifyReshapeReshape(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
// Look through the input of the current reshape.
|
||||
mlir::Value *reshapeInput = reshape.getOperand();
|
||||
mlir::Value *reshapeInput = op.getOperand();
|
||||
|
||||
// If the input is defined by another reshape, bingo!
|
||||
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
|
||||
return matchFailure();
|
||||
|
@ -134,18 +125,15 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
|
|||
};
|
||||
|
||||
/// Fold reshape(x)) -> x, when input type matches output type
|
||||
struct SimplifyNullReshape : public mlir::RewritePattern {
|
||||
SimplifyNullReshape(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
|
||||
if (op.getOperand()->getType() != op.getType())
|
||||
return matchFailure();
|
||||
rewriter.replaceOp(reshape, {reshape.getOperand()});
|
||||
rewriter.replaceOp(op, {op.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
|
@ -32,30 +32,26 @@ namespace toy {
|
|||
|
||||
namespace {
|
||||
|
||||
/// Fold transpose(transpose(x)) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
|
||||
/// Fold transpose(transpose(x) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
/// We register this pattern to match every toy.transpose in the IR.
|
||||
/// The "benefit" is used by the framework to order the patterns and process
|
||||
/// them in order of profitability.
|
||||
SimplifyRedundantTranspose(mlir::MLIRContext *context)
|
||||
: RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// This method is attempting to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// We can directly cast the current operation as this will only get invoked
|
||||
// on TransposeOp.
|
||||
TransposeOp transpose = llvm::cast<TransposeOp>(op);
|
||||
// look through the input to the current transpose
|
||||
mlir::Value *transposeInput = transpose.getOperand();
|
||||
mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
// Look through the input of the current transpose.
|
||||
mlir::Value *transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
|
||||
|
@ -66,25 +62,21 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
|
|||
};
|
||||
|
||||
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
|
||||
struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
||||
SimplifyReshapeConstant(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp reshape,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
// look through the input to the current reshape
|
||||
mlir::Value *reshapeInput = reshape.getOperand();
|
||||
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
|
||||
// If the input is defined by another reshape, bingo!
|
||||
ConstantOp constantOp =
|
||||
mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
|
||||
// Look through the input of the current reshape.
|
||||
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
|
||||
reshape.getOperand()->getDefiningOp());
|
||||
|
||||
// If the input is defined by another constant, bingo!
|
||||
if (!constantOp)
|
||||
return matchFailure();
|
||||
|
||||
auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
|
||||
auto reshapeType = reshape.getType().cast<ToyArrayType>();
|
||||
if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
|
||||
// FIXME Check matching of element count!
|
||||
|
@ -93,9 +85,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
||||
auto newConstant = rewriter.create<ConstantOp>(
|
||||
constantOp.getLoc(), reshapeType.getShape(), newAttr);
|
||||
rewriter.replaceOp(op, {newConstant});
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
|
||||
// Broadcast
|
||||
|
@ -106,9 +97,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
|
||||
reshapeType.getElementType());
|
||||
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
|
||||
auto newConstant = rewriter.create<ConstantOp>(
|
||||
constantOp.getLoc(), reshapeType.getShape(), newAttr);
|
||||
rewriter.replaceOp(op, {newConstant});
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else {
|
||||
llvm_unreachable("Unsupported Constant format");
|
||||
}
|
||||
|
@ -117,43 +107,35 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
|
|||
};
|
||||
|
||||
/// Fold reshape(reshape(x)) -> reshape(x)
|
||||
struct SimplifyReshapeReshape : public mlir::RewritePattern {
|
||||
SimplifyReshapeReshape(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
// look through the input to the current reshape
|
||||
mlir::Value *reshapeInput = reshape.getOperand();
|
||||
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
|
||||
// Look through the input of the current reshape.
|
||||
mlir::Value *reshapeInput = op.getOperand();
|
||||
|
||||
// If the input is defined by another reshape, bingo!
|
||||
ReshapeOp reshapeInputOp =
|
||||
mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
|
||||
if (!reshapeInputOp)
|
||||
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
|
||||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement
|
||||
rewriter.replaceOp(op, {reshapeInputOp});
|
||||
rewriter.replaceOp(op, {reshapeInput});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold reshape(x)) -> x, when input type matches output type
|
||||
struct SimplifyNullReshape : public mlir::RewritePattern {
|
||||
SimplifyNullReshape(mlir::MLIRContext *context)
|
||||
: RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
|
||||
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
|
||||
if (op.getOperand()->getType() != op.getType())
|
||||
return matchFailure();
|
||||
rewriter.replaceOp(reshape, {reshape.getOperand()});
|
||||
rewriter.replaceOp(op, {op.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -176,17 +158,14 @@ void ReshapeOp::getCanonicalizationPatterns(
|
|||
namespace {
|
||||
|
||||
/// Fold type.cast(x) -> x, when input type matches output type
|
||||
struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
|
||||
SimplifyIdentityTypeCast(mlir::MLIRContext *context)
|
||||
: RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
|
||||
context) {}
|
||||
struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
|
||||
using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
matchAndRewrite(TypeCastOp typeCast,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
|
||||
auto resTy = typeCast.getResult()->getType();
|
||||
auto *candidateOp = op;
|
||||
auto resTy = typeCast.getType();
|
||||
auto *candidateOp = typeCast.getOperation();
|
||||
while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
|
||||
if (resTy == candidateOp->getOperand(0)->getType()) {
|
||||
rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
|
||||
|
|
|
@ -205,6 +205,53 @@ protected:
|
|||
llvm::SmallVector<OperationName, 2> generatedOps;
|
||||
};
|
||||
|
||||
/// OpRewritePattern is a wrapper around RewritePattern that allows for
|
||||
/// matching and rewriting against an instance of a derived operation class as
|
||||
/// opposed to a raw Operation.
|
||||
template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching.
|
||||
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
|
||||
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
|
||||
|
||||
/// Wrappers around the RewritePattern methods that pass the derived op type.
|
||||
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
||||
PatternRewriter &rewriter) const final {
|
||||
rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
|
||||
}
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
|
||||
rewrite(llvm::cast<SourceOp>(op), rewriter);
|
||||
}
|
||||
PatternMatchResult match(Operation *op) const final {
|
||||
return match(llvm::cast<SourceOp>(op));
|
||||
}
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
|
||||
}
|
||||
|
||||
/// Rewrite and Match methods that operate on the SourceOp type. These must be
|
||||
/// overridden by the derived pattern class.
|
||||
virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
|
||||
PatternRewriter &rewriter) const {
|
||||
rewrite(op, rewriter);
|
||||
}
|
||||
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
|
||||
llvm_unreachable("must override matchAndRewrite or a rewrite method");
|
||||
}
|
||||
virtual PatternMatchResult match(SourceOp op) const {
|
||||
llvm_unreachable("must override match or matchAndRewrite");
|
||||
}
|
||||
virtual PatternMatchResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (auto matchResult = match(op)) {
|
||||
rewrite(op, std::move(*matchResult), rewriter);
|
||||
return matchSuccess();
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternRewriter class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -654,24 +654,21 @@ void mlir::canonicalizeMapAndOperands(
|
|||
namespace {
|
||||
/// Simplify AffineApply operations.
|
||||
///
|
||||
struct SimplifyAffineApply : public RewritePattern {
|
||||
SimplifyAffineApply(MLIRContext *context)
|
||||
: RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
|
||||
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
|
||||
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(AffineApplyOp apply,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto apply = cast<AffineApplyOp>(op);
|
||||
auto map = apply.getAffineMap();
|
||||
|
||||
AffineMap oldMap = map;
|
||||
SmallVector<Value *, 8> resultOperands(apply.getOperands());
|
||||
composeAffineMapAndOperands(&map, &resultOperands);
|
||||
if (map != oldMap) {
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(op, map, resultOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
if (map == oldMap)
|
||||
return matchFailure();
|
||||
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -1002,14 +999,11 @@ void AffineForOp::print(OpAsmPrinter *p) {
|
|||
|
||||
namespace {
|
||||
/// This is a pattern to fold constant loop bounds.
|
||||
struct AffineForLoopBoundFolder : public RewritePattern {
|
||||
/// The rootOpName is the name of the root operation to match against.
|
||||
AffineForLoopBoundFolder(MLIRContext *context)
|
||||
: RewritePattern(AffineForOp::getOperationName(), 1, context) {}
|
||||
struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
|
||||
using OpRewritePattern<AffineForOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(AffineForOp forOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<AffineForOp>(op);
|
||||
auto foldLowerOrUpperBound = [&forOp](bool lower) {
|
||||
// Check to see if each of the operands is the result of a constant. If
|
||||
// so, get the value. If not, ignore it.
|
||||
|
@ -1056,7 +1050,7 @@ struct AffineForLoopBoundFolder : public RewritePattern {
|
|||
// If any of the bounds were folded we return success.
|
||||
if (!folded)
|
||||
return matchFailure();
|
||||
rewriter.updatedRootInPlace(op);
|
||||
rewriter.updatedRootInPlace(forOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input,
|
|||
|
||||
namespace {
|
||||
|
||||
struct UniformDequantizePattern : public RewritePattern {
|
||||
UniformDequantizePattern(MLIRContext *context)
|
||||
: RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
|
||||
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
|
||||
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto dcastOp = cast<DequantizeCastOp>(op);
|
||||
Type inputType = dcastOp.arg()->getType();
|
||||
Type outputType = dcastOp.getResult()->getType();
|
||||
Type inputType = op.arg()->getType();
|
||||
Type outputType = op.getResult()->getType();
|
||||
|
||||
QuantizedType inputElementType =
|
||||
QuantizedType::getQuantizedElementType(inputType);
|
||||
|
@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern {
|
|||
return matchFailure();
|
||||
}
|
||||
|
||||
Value *dequantizedValue =
|
||||
emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
|
||||
Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
|
||||
if (!dequantizedValue) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
|
|||
|
||||
namespace {
|
||||
|
||||
struct UniformRealAddEwPattern : public RewritePattern {
|
||||
UniformRealAddEwPattern(MLIRContext *context)
|
||||
: RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
|
||||
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
|
||||
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(RealAddEwOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto addOp = cast<RealAddEwOp>(op);
|
||||
const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
|
||||
addOp.clamp_min(), addOp.clamp_max());
|
||||
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
|
||||
op.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern {
|
|||
}
|
||||
};
|
||||
|
||||
struct UniformRealMulEwPattern : public RewritePattern {
|
||||
UniformRealMulEwPattern(MLIRContext *context)
|
||||
: RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
|
||||
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
|
||||
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(RealMulEwOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto mulOp = cast<RealMulEwOp>(op);
|
||||
const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
|
||||
mulOp.clamp_min(), mulOp.clamp_max());
|
||||
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
|
||||
op.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
|
|
@ -38,26 +38,21 @@ namespace {
|
|||
|
||||
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
|
||||
/// value of x if the casts invert each other.
|
||||
class RemoveRedundantStorageCastsRewrite : public RewritePattern {
|
||||
class RemoveRedundantStorageCastsRewrite
|
||||
: public OpRewritePattern<StorageCastOp> {
|
||||
public:
|
||||
RemoveRedundantStorageCastsRewrite(MLIRContext *context)
|
||||
: RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
|
||||
using OpRewritePattern<StorageCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto scastOp = cast<StorageCastOp>(op);
|
||||
if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
|
||||
auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
|
||||
if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
|
||||
return matchSuccess();
|
||||
}
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
PatternMatchResult matchAndRewrite(StorageCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
|
||||
return matchFailure();
|
||||
auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
|
||||
if (srcScastOp.arg()->getType() != op.getType())
|
||||
return matchFailure();
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto scastOp = cast<StorageCastOp>(op);
|
||||
auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
|
||||
rewriter.replaceOp(op, srcScastOp.arg());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -36,40 +36,35 @@ public:
|
|||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
class QuantizedConstRewrite : public RewritePattern {
|
||||
public:
|
||||
struct State : PatternState {
|
||||
QuantizedType quantizedElementType;
|
||||
Attribute value;
|
||||
};
|
||||
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
|
||||
using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
|
||||
|
||||
QuantizedConstRewrite(MLIRContext *context)
|
||||
: RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override;
|
||||
void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
|
||||
PatternRewriter &rewriter) const override;
|
||||
PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
|
||||
/// quantized and the operand type is quantizable.
|
||||
PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
|
||||
State state;
|
||||
|
||||
PatternMatchResult
|
||||
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
||||
PatternRewriter &rewriter) const {
|
||||
Attribute value;
|
||||
|
||||
// Is the operand a constant?
|
||||
auto qbarrier = cast<QuantizeCastOp>(op);
|
||||
if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
|
||||
if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Does the qbarrier convert to a quantized type. This will not be true
|
||||
// if a quantized type has not yet been chosen or if the cast to an equivalent
|
||||
// storage type is not supported.
|
||||
Type qbarrierResultType = qbarrier.getResult()->getType();
|
||||
state.quantizedElementType =
|
||||
QuantizedType quantizedElementType =
|
||||
QuantizedType::getQuantizedElementType(qbarrierResultType);
|
||||
if (!state.quantizedElementType) {
|
||||
if (!quantizedElementType) {
|
||||
return matchFailure();
|
||||
}
|
||||
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
|
||||
|
@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
|
|||
// Is the operand type compatible with the expressed type of the quantized
|
||||
// type? This will not be true if the qbarrier is superfluous (converts
|
||||
// from and to a quantized type).
|
||||
if (!state.quantizedElementType.isCompatibleExpressedType(
|
||||
if (!quantizedElementType.isCompatibleExpressedType(
|
||||
qbarrier.arg()->getType())) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Is the constant value a type expressed in a way that we support?
|
||||
if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
|
||||
!state.value.isa<DenseElementsAttr>() &&
|
||||
!state.value.isa<SparseElementsAttr>()) {
|
||||
if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
|
||||
!value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
return matchSuccess(llvm::make_unique<State>(std::move(state)));
|
||||
}
|
||||
|
||||
void QuantizedConstRewrite::rewrite(Operation *op,
|
||||
std::unique_ptr<PatternState> baseState,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto state = static_cast<State *>(baseState.get());
|
||||
|
||||
Type newConstValueType;
|
||||
Attribute newConstValue = quantizeAttr(
|
||||
state->value, state->quantizedElementType, newConstValueType);
|
||||
auto newConstValue =
|
||||
quantizeAttr(value, quantizedElementType, newConstValueType);
|
||||
if (!newConstValue) {
|
||||
return;
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
auto *origConstOp = op->getOperand(0);
|
||||
// When creating the new const op, use a fused location that combines the
|
||||
// original const and the qbarrier that led to the quantization.
|
||||
auto fusedLoc =
|
||||
FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
|
||||
rewriter.getContext());
|
||||
auto fusedLoc = FusedLoc::get(
|
||||
{qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
|
||||
rewriter.getContext());
|
||||
auto newConstOp =
|
||||
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>(
|
||||
{origConstOp}, op, *op->result_type_begin(), newConstOp);
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
|
||||
qbarrier.getType(), newConstOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void ConvertConstPass::runOnFunction() {
|
||||
|
|
|
@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) {
|
|||
|
||||
namespace {
|
||||
/// Fold constant dimensions into an alloc operation.
|
||||
struct SimplifyAllocConst : public RewritePattern {
|
||||
SimplifyAllocConst(MLIRContext *context)
|
||||
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto alloc = cast<AllocOp>(op);
|
||||
struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
||||
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check to see if any dimensions operands are constants. If so, we can
|
||||
// substitute and drop them.
|
||||
for (auto *operand : alloc.getOperands())
|
||||
if (matchPattern(operand, m_ConstantIndex()))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
|
||||
return matchPattern(operand, m_ConstantIndex());
|
||||
}))
|
||||
return matchFailure();
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto allocOp = cast<AllocOp>(op);
|
||||
auto memrefType = allocOp.getType();
|
||||
auto memrefType = alloc.getType();
|
||||
|
||||
// Ok, we have one or more constant operands. Collect the non-constant ones
|
||||
// and keep track of the resultant memref type to build.
|
||||
|
@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern {
|
|||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp();
|
||||
auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
|
||||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||
|
@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern {
|
|||
} else {
|
||||
// Dynamic shape dimension not folded; copy operand from old memref.
|
||||
newShapeConstants.push_back(-1);
|
||||
newOperands.push_back(allocOp.getOperand(dynamicDimPos));
|
||||
newOperands.push_back(alloc.getOperand(dynamicDimPos));
|
||||
}
|
||||
dynamicDimPos++;
|
||||
}
|
||||
|
@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern {
|
|||
|
||||
// Create and insert the alloc op for the new memref.
|
||||
auto newAlloc =
|
||||
rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands);
|
||||
rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
|
||||
// Insert a cast so we have the same type as the old alloc.
|
||||
auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc,
|
||||
allocOp.getType());
|
||||
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
|
||||
alloc.getType());
|
||||
|
||||
rewriter.replaceOp(op, {resultCast}, droppedOperands);
|
||||
rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold alloc operations with no uses. Alloc has side effects on the heap,
|
||||
/// but can still be deleted if it has zero uses.
|
||||
struct SimplifyDeadAlloc : public RewritePattern {
|
||||
SimplifyDeadAlloc(MLIRContext *context)
|
||||
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
|
||||
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
|
||||
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if the alloc'ed value has any uses.
|
||||
auto alloc = cast<AllocOp>(op);
|
||||
if (!alloc.use_empty())
|
||||
return matchFailure();
|
||||
|
||||
// If it doesn't, we can eliminate it.
|
||||
op->erase();
|
||||
alloc.erase();
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
namespace {
|
||||
/// Fold indirect calls that have a constant function as the callee operand.
|
||||
struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
|
||||
SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
|
||||
: RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
|
||||
struct SimplifyIndirectCallWithKnownCallee
|
||||
: public OpRewritePattern<CallIndirectOp> {
|
||||
using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto indirectCall = cast<CallIndirectOp>(op);
|
||||
|
||||
// Check that the callee is a constant callee.
|
||||
FunctionAttr calledFn;
|
||||
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
|
||||
return matchFailure();
|
||||
|
||||
// Replace with a direct call.
|
||||
SmallVector<Type, 8> callResults(op->getResultTypes());
|
||||
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
|
||||
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
|
||||
callOperands);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
|
||||
callResults, callOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -964,14 +956,11 @@ namespace {
|
|||
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
|
||||
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
|
||||
///
|
||||
struct SimplifyConstCondBranchPred : public RewritePattern {
|
||||
SimplifyConstCondBranchPred(MLIRContext *context)
|
||||
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
|
||||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto condbr = cast<CondBranchOp>(op);
|
||||
|
||||
// Check that the condition is a constant.
|
||||
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
|
||||
return matchFailure();
|
||||
|
@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
|
|||
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
|
|||
namespace {
|
||||
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
|
||||
/// by other Dealloc operations.
|
||||
struct SimplifyDeadDealloc : public RewritePattern {
|
||||
SimplifyDeadDealloc(MLIRContext *context)
|
||||
: RewritePattern(DeallocOp::getOperationName(), 1, context) {}
|
||||
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
||||
using OpRewritePattern<DeallocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(DeallocOp dealloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dealloc = cast<DeallocOp>(op);
|
||||
|
||||
// Check that the memref operand's defining operation is an AllocOp.
|
||||
Value *memref = dealloc.memref();
|
||||
Operation *defOp = memref->getDefiningOp();
|
||||
if (!isa_and_nonnull<AllocOp>(defOp))
|
||||
if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
|
||||
return matchFailure();
|
||||
|
||||
// Check that all of the uses of the AllocOp are other DeallocOps.
|
||||
|
@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
|
|||
return matchFailure();
|
||||
|
||||
// Erase the dealloc operation.
|
||||
op->erase();
|
||||
rewriter.replaceOp(dealloc, llvm::None);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue