forked from OSchip/llvm-project
[mlir] Remove 'valuesToRemoveIfDead' from PatternRewriter API
Summary: Remove 'valuesToRemoveIfDead' from PatternRewriter API. The removal functionality wasn't implemented and we decided [1] not to implement it in favor of having more powerful DCE approaches. [1] https://github.com/tensorflow/mlir/pull/212 Reviewers: rriddle, bondhugula Reviewed By: rriddle Subscribers: liufengdb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72545
This commit is contained in:
parent
27f2e9ab1c
commit
6fb3d59746
|
@ -48,7 +48,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -71,7 +71,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -318,33 +318,15 @@ public:
|
|||
|
||||
/// This method performs the final replacement for a pattern, where the
|
||||
/// results of the operation are updated to use the specified list of SSA
|
||||
/// values. In addition to replacing and removing the specified operation,
|
||||
/// clients can specify a list of other nodes that this replacement may make
|
||||
/// (perhaps transitively) dead. If any of those values are dead, this will
|
||||
/// remove them as well.
|
||||
virtual void replaceOp(Operation *op, ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead);
|
||||
void replaceOp(Operation *op, ValueRange newValues) {
|
||||
replaceOp(op, newValues, llvm::None);
|
||||
}
|
||||
/// values.
|
||||
virtual void replaceOp(Operation *op, ValueRange newValues);
|
||||
|
||||
/// Replaces the result op with a new op that is created without verification.
|
||||
/// The result values of the two ops must be the same types.
|
||||
template <typename OpTy, typename... Args>
|
||||
void replaceOpWithNewOp(Operation *op, Args &&... args) {
|
||||
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
||||
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
|
||||
}
|
||||
|
||||
/// Replaces the result op with a new op that is created without verification.
|
||||
/// The result values of the two ops must be the same types. This allows
|
||||
/// specifying a list of ops that may be removed if dead.
|
||||
template <typename OpTy, typename... Args>
|
||||
void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op,
|
||||
Args &&... args) {
|
||||
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
||||
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
|
||||
valuesToRemoveIfDead);
|
||||
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
|
||||
}
|
||||
|
||||
/// This method erases an operation that is known to have no uses.
|
||||
|
@ -405,10 +387,9 @@ protected:
|
|||
virtual void notifyOperationRemoved(Operation *op) {}
|
||||
|
||||
private:
|
||||
/// op and newOp are known to have the same number of results, replace the
|
||||
/// uses of op with uses of newOp
|
||||
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
|
||||
ValueRange valuesToRemoveIfDead);
|
||||
/// 'op' and 'newOp' are known to have the same number of results, replace the
|
||||
/// uses of op with uses of newOp.
|
||||
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -332,8 +332,7 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void replaceOp(Operation *op, ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead) override;
|
||||
void replaceOp(Operation *op, ValueRange newValues) override;
|
||||
using PatternRewriter::replaceOp;
|
||||
|
||||
/// PatternRewriter hook for erasing a dead operation. The uses of this
|
||||
|
|
|
@ -90,8 +90,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
rewriter.getContext());
|
||||
auto newConstOp =
|
||||
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
|
||||
qbarrier.getType(), newConstOp);
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
|
||||
newConstOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
|
|
@ -328,7 +328,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
SmallVector<int64_t, 4> newShapeConstants;
|
||||
newShapeConstants.reserve(memrefType.getRank());
|
||||
SmallVector<Value, 4> newOperands;
|
||||
SmallVector<Value, 4> droppedOperands;
|
||||
|
||||
unsigned dynamicDimPos = 0;
|
||||
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
||||
|
@ -342,8 +341,6 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||
// Record to check for zero uses later below.
|
||||
droppedOperands.push_back(constantIndexOp);
|
||||
} else {
|
||||
// Dynamic shape dimension not folded; copy operand from old memref.
|
||||
newShapeConstants.push_back(-1);
|
||||
|
@ -366,7 +363,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
|
||||
alloc.getType());
|
||||
|
||||
rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
|
||||
rewriter.replaceOp(alloc, {resultCast});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2447,7 +2444,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
return matchFailure();
|
||||
|
||||
SmallVector<Value, 4> newOperands;
|
||||
SmallVector<Value, 4> droppedOperands;
|
||||
|
||||
// Fold dynamic offset operand if it is produced by a constant.
|
||||
auto dynamicOffset = viewOp.getDynamicOffset();
|
||||
|
@ -2458,7 +2454,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic offset will be folded into the map.
|
||||
newOffset = constantIndexOp.getValue();
|
||||
droppedOperands.push_back(dynamicOffset);
|
||||
} else {
|
||||
// Unable to fold dynamic offset. Add it to 'newOperands' list.
|
||||
newOperands.push_back(dynamicOffset);
|
||||
|
@ -2483,8 +2478,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||
// Record to check for zero uses later below.
|
||||
droppedOperands.push_back(constantIndexOp);
|
||||
} else {
|
||||
// Dynamic shape dimension not folded; copy operand from old memref.
|
||||
newShapeConstants.push_back(dimSize);
|
||||
|
@ -2522,8 +2515,8 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
|
||||
viewOp.getOperand(0), newOperands);
|
||||
// Insert a cast so we have the same type as the old memref type.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
|
||||
newViewOp, viewOp.getType());
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
|
||||
viewOp.getType());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2542,8 +2535,8 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
|
|||
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
|
||||
if (!allocOp)
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
|
||||
allocOperand, viewOp.operands());
|
||||
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
|
||||
viewOp.operands());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2839,8 +2832,8 @@ public:
|
|||
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
|
||||
ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
|
||||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
||||
subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType());
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2889,8 +2882,8 @@ public:
|
|||
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
|
||||
subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
|
||||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
||||
subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType());
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2941,8 +2934,8 @@ public:
|
|||
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
|
||||
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
|
||||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
||||
subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType());
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
|
||||
|
@ -72,12 +73,8 @@ PatternRewriter::~PatternRewriter() {
|
|||
|
||||
/// This method performs the final replacement for a pattern, where the
|
||||
/// results of the operation are updated to use the specified list of SSA
|
||||
/// values. In addition to replacing and removing the specified operation,
|
||||
/// clients can specify a list of other nodes that this replacement may make
|
||||
/// (perhaps transitively) dead. If any of those ops are dead, this will
|
||||
/// remove them as well.
|
||||
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead) {
|
||||
/// values.
|
||||
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
|
||||
// Notify the rewriter subclass that we're about to replace this root.
|
||||
notifyRootReplaced(op);
|
||||
|
||||
|
@ -87,9 +84,6 @@ void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
|
|||
|
||||
notifyOperationRemoved(op);
|
||||
op->erase();
|
||||
|
||||
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
|
||||
// the notifyOperationRemoved hook in the process.
|
||||
}
|
||||
|
||||
/// This method erases an operation that is known to have no uses. The uses of
|
||||
|
@ -129,15 +123,15 @@ Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
|
|||
return block->splitBlock(before);
|
||||
}
|
||||
|
||||
/// op and newOp are known to have the same number of results, replace the
|
||||
/// 'op' and 'newOp' are known to have the same number of results, replace the
|
||||
/// uses of op with uses of newOp
|
||||
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
|
||||
Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) {
|
||||
void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
|
||||
Operation *newOp) {
|
||||
assert(op->getNumResults() == newOp->getNumResults() &&
|
||||
"replacement op doesn't match results of original op");
|
||||
if (op->getNumResults() == 1)
|
||||
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
|
||||
return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead);
|
||||
return replaceOp(op, newOp->getResult(0));
|
||||
return replaceOp(op, newOp->getResults());
|
||||
}
|
||||
|
||||
/// Move the blocks that belong to "region" before the given position in
|
||||
|
|
|
@ -554,8 +554,7 @@ struct ConversionPatternRewriterImpl {
|
|||
TypeConverter::SignatureConversion &conversion);
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void replaceOp(Operation *op, ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead);
|
||||
void replaceOp(Operation *op, ValueRange newValues);
|
||||
|
||||
/// Notifies that a block was split.
|
||||
void notifySplitBlock(Block *block, Block *continuation);
|
||||
|
@ -757,8 +756,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
|||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::replaceOp(Operation *op,
|
||||
ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead) {
|
||||
ValueRange newValues) {
|
||||
assert(newValues.size() == op->getNumResults());
|
||||
|
||||
// Create mappings for each of the new result values.
|
||||
|
@ -838,11 +836,11 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
|
|||
ConversionPatternRewriter::~ConversionPatternRewriter() {}
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
|
||||
ValueRange valuesToRemoveIfDead) {
|
||||
void ConversionPatternRewriter::replaceOp(Operation *op,
|
||||
ValueRange newValues) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
|
||||
<< "\n");
|
||||
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
|
||||
impl->replaceOp(op, newValues);
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for erasing a dead operation. The uses of this
|
||||
|
@ -852,7 +850,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
|
|||
LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName()
|
||||
<< "\n");
|
||||
SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
|
||||
impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None);
|
||||
impl->replaceOp(op, nullRepls);
|
||||
}
|
||||
|
||||
/// Apply a signature conversion to the entry block of the given region.
|
||||
|
|
Loading…
Reference in New Issue