[mlir][vector] Move transferOp on tensor opt to folder/canonicalization

Move the existing optimization for transfer op on tensor to folder and
canonicalization. This handles the write after write case and read after write
and also add write after read case.

Differential Revision: https://reviews.llvm.org/D100597
This commit is contained in:
thomasraoux 2021-04-15 13:43:44 -07:00
parent 517c3aee4d
commit 3fc0fbefc8
8 changed files with 295 additions and 121 deletions

View File

@ -1421,6 +1421,7 @@ def Vector_TransferWriteOp :
];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Vector_LoadOp : Vector_Op<"load"> {

View File

@ -28,6 +28,11 @@ class Value;
class VectorType;
class VectorTransferOpInterface;
namespace vector {
class TransferWriteOp;
class TransferReadOp;
} // namespace vector
/// Return the number of elements of basis, `0` if empty.
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
@ -177,6 +182,16 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);
/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
bool checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read);
/// Return true if the write op fully over-write the priorWrite transfer_write
/// op.
bool checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite);
namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a

View File

@ -2512,7 +2512,35 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
return success();
}
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
/// : tensor<4x4xf32>, vector<1x4xf32>
/// ```
/// -> Folds into
/// ```
/// %v0
/// ```
static Value foldRAW(TransferReadOp readOp) {
if (!readOp.getShapedType().isa<RankedTensorType>())
return {};
auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueRAW(defWrite, readOp))
return defWrite.vector();
if (!isDisjointTransferIndices(
cast<VectorTransferOpInterface>(defWrite.getOperation()),
cast<VectorTransferOpInterface>(readOp.getOperation())))
break;
defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
}
return {};
}
OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
if (Value vec = foldRAW(*this))
return vec;
/// transfer_read(memrefcast) -> transfer_read
if (succeeded(foldTransferInBoundsAttribute(*this)))
return getResult();
@ -2724,10 +2752,47 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
vector::TransferWriteOp write) {
return read.source() == write.source() && read.indices() == write.indices() &&
read.permutation_map() == write.permutation_map() &&
read.getVectorType() == write.getVectorType() && !read.mask() &&
!write.mask();
}
/// Fold transfer_write write after read:
/// ```
/// %t0 = ...
/// %v = vector.transfer_read %t0[%c0...] :
/// tensor<static_sizesxf32>, vector<static_sizesxf32>
/// %t1 = vector.transfer_write %v, %t0[%c0...] :
/// vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
/// %t0
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
SmallVectorImpl<OpFoldResult> &results) {
if (!write.source().getType().isa<RankedTensorType>())
return failure();
auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
if (!read)
return failure();
if (!checkSameValueWAR(read, write))
return failure();
results.push_back(read.source());
return success();
}
LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
if (succeeded(foldReadInitWrite(*this, operands, results)))
return success();
if (succeeded(foldWAR(*this, results)))
return success();
if (succeeded(foldTransferInBoundsAttribute(*this)))
return success();
return foldMemRefCast(*this);
@ -2745,6 +2810,67 @@ void TransferWriteOp::getEffects(
SideEffects::DefaultResource::get());
}
namespace {
/// Remove dead transfer write from the SSA chain so that it an be eliminated by
/// DCE
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
class foldWAW final : public OpRewritePattern<TransferWriteOp> {
public:
using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
if (!writeOp.getShapedType().isa<RankedTensorType>())
return failure();
vector::TransferWriteOp writeToModify = writeOp;
auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueWAW(writeOp, defWrite)) {
writeToModify.sourceMutable().assign(defWrite.source());
return success();
}
if (!isDisjointTransferIndices(
cast<VectorTransferOpInterface>(defWrite.getOperation()),
cast<VectorTransferOpInterface>(writeOp.getOperation())))
break;
// If the previous write op doesn't have any other use we an safely look
// at the previous store to see if it can be removed.
if (!defWrite->hasOneUse())
break;
writeToModify = defWrite;
defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
}
return failure();
}
};
} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<foldWAW>(context);
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

View File

@ -34,34 +34,13 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
return op;
}
/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
static bool transferEncompasses(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
return !defWrite.hasOutOfBoundsDim() &&
defWrite.indices() == read.indices() &&
defWrite.getVectorType() == read.getVectorType() &&
defWrite.permutation_map() == read.permutation_map();
}
/// Return true if the write op fully over-write the priorWrite transfer_write
/// op.
static bool transferEncompasses(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite) {
return priorWrite.indices() == write.indices() &&
priorWrite.getVectorType() == write.getVectorType() &&
priorWrite.permutation_map() == write.permutation_map();
}
namespace {
class TransferOptimization {
public:
TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
void deadStoreOp(vector::TransferWriteOp);
void deadStoreOpTensor(vector::TransferWriteOp);
void storeToLoadForwarding(vector::TransferReadOp);
void storeToLoadForwardingTensor(vector::TransferReadOp);
void removeDeadOp() {
for (Operation *op : opToErase)
op->erase();
@ -120,7 +99,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check candidate that can override the store.
if (transferEncompasses(nextWrite, write) &&
if (checkSameValueWAW(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@ -192,8 +171,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
cast<VectorTransferOpInterface>(write.getOperation()),
cast<VectorTransferOpInterface>(read.getOperation())))
continue;
if (dominators.dominates(write, read) &&
transferEncompasses(write, read)) {
if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
else
@ -231,44 +209,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
opToErase.push_back(read.getOperation());
}
/// Walk up the SSA links, if any write gets fully overwritten we can skip it.
/// If it has no more uses it becomes dead.
void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (transferEncompasses(write, defWrite)) {
write.sourceMutable().assign(defWrite.source());
if (defWrite->use_empty())
opToErase.push_back(defWrite.getOperation());
return;
}
if (!isDisjointTransferIndices(
cast<VectorTransferOpInterface>(defWrite.getOperation()),
cast<VectorTransferOpInterface>(write.getOperation())))
break;
defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
}
}
/// Walk up the SSA links, if any write fully match the written vector we can
/// replace the read by the vector. The read becomes dead and can be removed.
void TransferOptimization::storeToLoadForwardingTensor(
vector::TransferReadOp read) {
auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (transferEncompasses(defWrite, read)) {
read.replaceAllUsesWith(defWrite.vector());
opToErase.push_back(read.getOperation());
return;
}
if (!isDisjointTransferIndices(
cast<VectorTransferOpInterface>(defWrite.getOperation()),
cast<VectorTransferOpInterface>(read.getOperation())))
break;
defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
}
}
} // namespace
void mlir::vector::transferOpflowOpt(FuncOp func) {
@ -278,15 +218,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) {
func.walk([&](vector::TransferReadOp read) {
if (read.getShapedType().isa<MemRefType>())
opt.storeToLoadForwarding(read);
else
opt.storeToLoadForwardingTensor(read);
});
opt.removeDeadOp();
func.walk([&](vector::TransferWriteOp write) {
if (write.getShapedType().isa<MemRefType>())
opt.deadStoreOp(write);
else
opt.deadStoreOpTensor(write);
});
opt.removeDeadOp();
}

View File

@ -354,3 +354,19 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
return false;
return isDisjointTransferIndices(transferA, transferB);
}
bool mlir::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) {
return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
defWrite.indices() == read.indices() &&
defWrite.getVectorType() == read.getVectorType() &&
defWrite.permutation_map() == read.permutation_map();
}
bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite) {
return priorWrite.indices() == write.indices() &&
priorWrite.mask() == write.mask() &&
priorWrite.getVectorType() == write.getVectorType() &&
priorWrite.permutation_map() == write.permutation_map();
}

View File

@ -799,3 +799,136 @@ func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
// CHECK-NEXT: return %[[T0]], %[[T0]], %[[T0]]
return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
}
// -----
// CHECK-LABEL: func @store_after_load_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>)
// CHECK-NOT: vector.transfer_read
// CHECK-NOT: vector.transfer_write
// CHECK: return %[[ARG]] : tensor<4x4xf32>
func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
tensor<4x4xf32>, vector<1x4xf32>
%w0 = vector.transfer_write %0, %arg0[%c1, %c0] :
vector<1x4xf32>, tensor<4x4xf32>
return %w0 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @store_after_load_tensor_negative
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: return
func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
tensor<4x4xf32>, vector<1x4xf32>
%w0 = vector.transfer_write %0, %arg0[%c0, %c0] :
vector<1x4xf32>, tensor<4x4xf32>
return %w0 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @store_to_load_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<1x4xf32>, %[[V1:.*]]: vector<1x4xf32>)
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: return %[[V0]] : vector<1x4xf32>
func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>) -> vector<1x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v1, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// -----
// CHECK-LABEL: func @store_to_load_negative_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: %[[V:.*]] = vector.transfer_read
// CHECK: return %[[V]] : vector<1x4xf32>
func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%i, %i] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// -----
// CHECK-LABEL: func @dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-NOT: vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
// CHECK: vector.transfer_write {{.*}}, {{.*}}[%[[C2]], %[[C0]]
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w2 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @dead_store_tensor_negative
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%i, %i], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = addf %0, %0 : vector<1x4xf32>
%w2 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w2 : tensor<4x4xf32>
}

View File

@ -112,11 +112,11 @@ func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) ->
func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>) ->
tensor<4x4xf32> {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
%r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
%r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
return %r: tensor<4x4xf32>
}

View File

@ -184,56 +184,3 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
return
}
// CHECK-LABEL: func @forward_dead_store_tensor
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-> (vector<1x4xf32>) {
%1 = addf %acc, %acc : vector<1x4xf32>
scf.yield %1 : vector<1x4xf32>
}
%w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w1 : tensor<4x4xf32>
}
// CHECK-LABEL: func @forward_dead_store_negative_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg1[%c1, %i] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-> (vector<1x4xf32>) {
%1 = addf %acc, %acc : vector<1x4xf32>
scf.yield %1 : vector<1x4xf32>
}
%w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w1 : tensor<4x4xf32>
}