[mlir] Remove 'valuesToRemoveIfDead' from PatternRewriter API

Summary:
Remove 'valuesToRemoveIfDead' from PatternRewriter API. The removal
functionality wasn't implemented and we decided [1] not to implement it in
favor of having more powerful DCE approaches.

[1] https://github.com/tensorflow/mlir/pull/212

Reviewers: rriddle, bondhugula

Reviewed By: rriddle

Subscribers: liufengdb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72545
This commit is contained in:
Diego Caballero 2020-01-27 13:13:20 -08:00
parent 27f2e9ab1c
commit 6fb3d59746
11 changed files with 39 additions and 74 deletions

View File

@ -48,7 +48,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
return matchFailure();
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
}
};

View File

@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
return matchFailure();
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
}
};

View File

@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
return matchFailure();
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
}
};

View File

@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
return matchFailure();
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
}
};

View File

@ -71,7 +71,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
return matchFailure();
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
}
};

View File

@ -318,33 +318,15 @@ public:
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values. In addition to replacing and removing the specified operation,
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
virtual void replaceOp(Operation *op, ValueRange newValues,
ValueRange valuesToRemoveIfDead);
void replaceOp(Operation *op, ValueRange newValues) {
replaceOp(op, newValues, llvm::None);
}
/// values.
virtual void replaceOp(Operation *op, ValueRange newValues);
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(Operation *op, Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
}
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types. This allows
/// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op,
Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
valuesToRemoveIfDead);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
}
/// This method erases an operation that is known to have no uses.
@ -405,10 +387,9 @@ protected:
virtual void notifyOperationRemoved(Operation *op) {}
private:
/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
ValueRange valuesToRemoveIfDead);
/// 'op' and 'newOp' are known to have the same number of results, replace the
/// uses of op with uses of newOp.
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
//===----------------------------------------------------------------------===//

View File

@ -332,8 +332,7 @@ public:
//===--------------------------------------------------------------------===//
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues,
ValueRange valuesToRemoveIfDead) override;
void replaceOp(Operation *op, ValueRange newValues) override;
using PatternRewriter::replaceOp;
/// PatternRewriter hook for erasing a dead operation. The uses of this

View File

@ -90,8 +90,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
qbarrier.getType(), newConstOp);
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
newConstOp);
return matchSuccess();
}

View File

@ -328,7 +328,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
SmallVector<int64_t, 4> newShapeConstants;
newShapeConstants.reserve(memrefType.getRank());
SmallVector<Value, 4> newOperands;
SmallVector<Value, 4> droppedOperands;
unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@ -342,8 +341,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
// Record to check for zero uses later below.
droppedOperands.push_back(constantIndexOp);
} else {
// Dynamic shape dimension not folded; copy operand from old memref.
newShapeConstants.push_back(-1);
@ -366,7 +363,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
alloc.getType());
rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
rewriter.replaceOp(alloc, {resultCast});
return matchSuccess();
}
};
@ -2447,7 +2444,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
return matchFailure();
SmallVector<Value, 4> newOperands;
SmallVector<Value, 4> droppedOperands;
// Fold dynamic offset operand if it is produced by a constant.
auto dynamicOffset = viewOp.getDynamicOffset();
@ -2458,7 +2454,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic offset will be folded into the map.
newOffset = constantIndexOp.getValue();
droppedOperands.push_back(dynamicOffset);
} else {
// Unable to fold dynamic offset. Add it to 'newOperands' list.
newOperands.push_back(dynamicOffset);
@ -2483,8 +2478,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
// Record to check for zero uses later below.
droppedOperands.push_back(constantIndexOp);
} else {
// Dynamic shape dimension not folded; copy operand from old memref.
newShapeConstants.push_back(dimSize);
@ -2522,8 +2515,8 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
viewOp.getOperand(0), newOperands);
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
newViewOp, viewOp.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
viewOp.getType());
return matchSuccess();
}
};
@ -2542,8 +2535,8 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
if (!allocOp)
return matchFailure();
rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
allocOperand, viewOp.operands());
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
viewOp.operands());
return matchSuccess();
}
};
@ -2839,8 +2832,8 @@ public:
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};
@ -2889,8 +2882,8 @@ public:
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};
@ -2941,8 +2934,8 @@ public:
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType());
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};

View File

@ -10,6 +10,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
using namespace mlir;
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
@ -72,12 +73,8 @@ PatternRewriter::~PatternRewriter() {
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values. In addition to replacing and removing the specified operation,
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those ops are dead, this will
/// remove them as well.
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
ValueRange valuesToRemoveIfDead) {
/// values.
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
@ -87,9 +84,6 @@ void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
notifyOperationRemoved(op);
op->erase();
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
// the notifyOperationRemoved hook in the process.
}
/// This method erases an operation that is known to have no uses. The uses of
@ -129,15 +123,15 @@ Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}
/// op and newOp are known to have the same number of results, replace the
/// 'op' and 'newOp' are known to have the same number of results, replace the
/// uses of op with uses of newOp
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) {
void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() &&
"replacement op doesn't match results of original op");
if (op->getNumResults() == 1)
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead);
return replaceOp(op, newOp->getResult(0));
return replaceOp(op, newOp->getResults());
}
/// Move the blocks that belong to "region" before the given position in

View File

@ -554,8 +554,7 @@ struct ConversionPatternRewriterImpl {
TypeConverter::SignatureConversion &conversion);
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues,
ValueRange valuesToRemoveIfDead);
void replaceOp(Operation *op, ValueRange newValues);
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@ -757,8 +756,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
void ConversionPatternRewriterImpl::replaceOp(Operation *op,
ValueRange newValues,
ValueRange valuesToRemoveIfDead) {
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
// Create mappings for each of the new result values.
@ -838,11 +836,11 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
ConversionPatternRewriter::~ConversionPatternRewriter() {}
/// PatternRewriter hook for replacing the results of an operation.
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
ValueRange valuesToRemoveIfDead) {
void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
<< "\n");
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
impl->replaceOp(op, newValues);
}
/// PatternRewriter hook for erasing a dead operation. The uses of this
@ -852,7 +850,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName()
<< "\n");
SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None);
impl->replaceOp(op, nullRepls);
}
/// Apply a signature conversion to the entry block of the given region.