[mlir][vector] Move transferOp on tensor opt to folder/canonicalization

Move the existing optimizations for transfer ops on tensors into the folder and
canonicalization patterns. This handles the write-after-write and
read-after-write cases, and also adds a write-after-read case.
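
For instance, the write-after-read case folds a transfer_write that stores back
a value just read from the same indices of the same tensor (a sketch mirroring
the `@store_after_load_tensor` test added below; value names are illustrative):

```
%v = vector.transfer_read %t[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<1x4xf32>
%w = vector.transfer_write %v, %t[%c0, %c0] : vector<1x4xf32>, tensor<4x4xf32>
// %w folds away; its uses are replaced by %t.
```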

Differential Revision: https://reviews.llvm.org/D100597
thomasraoux 2021-04-15 13:43:44 -07:00
parent 517c3aee4d
commit 3fc0fbefc8
8 changed files with 295 additions and 121 deletions


@@ -1421,6 +1421,7 @@ def Vector_TransferWriteOp :
  ];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Vector_LoadOp : Vector_Op<"load"> {


@@ -28,6 +28,11 @@ class Value;
class VectorType;
class VectorTransferOpInterface;

namespace vector {
class TransferWriteOp;
class TransferReadOp;
} // namespace vector

/// Return the number of elements of basis, `0` if empty.
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);

@@ -177,6 +182,16 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
                               VectorTransferOpInterface transferB);
/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
bool checkSameValueRAW(vector::TransferWriteOp defWrite,
                       vector::TransferReadOp read);

/// Return true if the write op fully over-writes the priorWrite
/// transfer_write op.
bool checkSameValueWAW(vector::TransferWriteOp write,
                       vector::TransferWriteOp priorWrite);
namespace matcher {

/// Matches vector.transfer_read, vector.transfer_write and ops that return a


@@ -2512,7 +2512,35 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  return success();
}
/// Fold a read after a write when the read loads exactly the value that was
/// just written:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///   : tensor<4x4xf32>, vector<1x4xf32>
/// ```
/// -> Folds into
/// ```
/// %v0
/// ```
static Value foldRAW(TransferReadOp readOp) {
  if (!readOp.getShapedType().isa<RankedTensorType>())
    return {};
  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.vector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}
OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
@@ -2724,10 +2752,47 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.source() == write.source() && read.indices() == write.indices() &&
         read.permutation_map() == write.permutation_map() &&
         read.getVectorType() == write.getVectorType() && !read.mask() &&
         !write.mask();
}

/// Fold transfer_write write after read:
/// ```
/// %t0 = ...
/// %v = vector.transfer_read %t0[%c0...] :
///   tensor<static_sizesxf32>, vector<static_sizesxf32>
/// %t1 = vector.transfer_write %v, %t0[%c0...] :
///   vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
/// %t0
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!write.source().getType().isa<RankedTensorType>())
    return failure();
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.source());
  return success();
}
LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, operands, results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  return foldMemRefCast(*this);
@@ -2745,6 +2810,67 @@ void TransferWriteOp::getEffects(
                       SideEffects::DefaultResource::get());
}
namespace {
/// Remove a dead transfer write from the SSA chain so that it can be
/// eliminated by DCE:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// The `%w0 = vector.transfer_write` op will then be removed by DCE if it
/// doesn't have any other uses.
class foldWAW final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;
    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.sourceMutable().assign(defWrite.source());
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<foldWAW>(context);
}
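
Both the folders and the new pattern fire under ordinary canonicalization. As a
usage sketch (the `-canonicalize` pass is the standard driver; the RUN line is
illustrative of how the updated tests below exercise it):

```
// RUN: mlir-opt %s -canonicalize | FileCheck %s
```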
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//


@@ -34,34 +34,13 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  return op;
}
/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
static bool transferEncompasses(vector::TransferWriteOp defWrite,
                                vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

/// Return true if the write op fully over-writes the priorWrite
/// transfer_write op.
static bool transferEncompasses(vector::TransferWriteOp write,
                                vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}
namespace {
class TransferOptimization {
public:
  TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
  void deadStoreOp(vector::TransferWriteOp);
  void deadStoreOpTensor(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void storeToLoadForwardingTensor(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
@@ -120,7 +99,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -192,8 +171,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
            cast<VectorTransferOpInterface>(write.getOperation()),
            cast<VectorTransferOpInterface>(read.getOperation())))
      continue;
    if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
      if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
        lastwrite = write;
      else
@@ -231,44 +209,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  opToErase.push_back(read.getOperation());
}
/// Walk up the SSA links; if any write gets fully overwritten we can skip it.
/// If it has no more uses it becomes dead.
void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (transferEncompasses(write, defWrite)) {
      write.sourceMutable().assign(defWrite.source());
      if (defWrite->use_empty())
        opToErase.push_back(defWrite.getOperation());
      return;
    }
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(write.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
}

/// Walk up the SSA links; if any write fully matches the written vector we can
/// replace the read by the vector. The read becomes dead and can be removed.
void TransferOptimization::storeToLoadForwardingTensor(
    vector::TransferReadOp read) {
  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (transferEncompasses(defWrite, read)) {
      read.replaceAllUsesWith(defWrite.vector());
      opToErase.push_back(read.getOperation());
      return;
    }
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(read.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
}
} // namespace

void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -278,15 +218,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) {
  func.walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
    else
      opt.storeToLoadForwardingTensor(read);
  });
  opt.removeDeadOp();
  func.walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
    else
      opt.deadStoreOpTensor(write);
  });
  opt.removeDeadOp();
}


@@ -354,3 +354,19 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}
bool mlir::checkSameValueRAW(vector::TransferWriteOp defWrite,
                             vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
                             vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.mask() == write.mask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}


@@ -799,3 +799,136 @@ func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
  // CHECK-NEXT: return %[[T0]], %[[T0]], %[[T0]]
  return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
}
// -----
// CHECK-LABEL: func @store_after_load_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>)
// CHECK-NOT: vector.transfer_read
// CHECK-NOT: vector.transfer_write
// CHECK: return %[[ARG]] : tensor<4x4xf32>
func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
tensor<4x4xf32>, vector<1x4xf32>
%w0 = vector.transfer_write %0, %arg0[%c1, %c0] :
vector<1x4xf32>, tensor<4x4xf32>
return %w0 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @store_after_load_tensor_negative
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: return
func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
tensor<4x4xf32>, vector<1x4xf32>
%w0 = vector.transfer_write %0, %arg0[%c0, %c0] :
vector<1x4xf32>, tensor<4x4xf32>
return %w0 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @store_to_load_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<1x4xf32>, %[[V1:.*]]: vector<1x4xf32>)
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: return %[[V0]] : vector<1x4xf32>
func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>) -> vector<1x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v1, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// -----
// CHECK-LABEL: func @store_to_load_negative_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: %[[V:.*]] = vector.transfer_read
// CHECK: return %[[V]] : vector<1x4xf32>
func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%i, %i] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// -----
// CHECK-LABEL: func @dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-NOT: vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
// CHECK: vector.transfer_write {{.*}}, {{.*}}[%[[C2]], %[[C0]]
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w2 : tensor<4x4xf32>
}
// -----
// CHECK-LABEL: func @dead_store_tensor_negative
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w1[%i, %i], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = addf %0, %0 : vector<1x4xf32>
%w2 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w2 : tensor<4x4xf32>
}


@@ -112,11 +112,11 @@ func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
// CHECK-NEXT:   %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT:   return %[[VTW3]] : tensor<4x4xf32>
func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>) ->
  tensor<4x4xf32> {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
  %r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
  return %r: tensor<4x4xf32>
}


@@ -184,56 +184,3 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
  return
}
// CHECK-LABEL: func @forward_dead_store_tensor
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-> (vector<1x4xf32>) {
%1 = addf %acc, %acc : vector<1x4xf32>
scf.yield %1 : vector<1x4xf32>
}
%w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w1 : tensor<4x4xf32>
}
// CHECK-LABEL: func @forward_dead_store_negative_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: %[[VTW:.*]] = vector.transfer_write
// CHECK: return %[[VTW]] : tensor<4x4xf32>
func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = constant 1 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%w0 = vector.transfer_write %v0, %arg1[%c1, %i] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
%0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
tensor<4x4xf32>, vector<1x4xf32>
%x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-> (vector<1x4xf32>) {
%1 = addf %acc, %acc : vector<1x4xf32>
scf.yield %1 : vector<1x4xf32>
}
%w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
vector<1x4xf32>, tensor<4x4xf32>
return %w1 : tensor<4x4xf32>
}