[mlir][vector] Don't duplicate transfer_read during vector distribution

Only apply the pattern if the transfer_read can be distributed for all
its uses.

Differential Revision: https://reviews.llvm.org/D133538
This commit is contained in:
Thomas Raoux 2022-09-08 22:57:54 +00:00
parent da8c9521ee
commit 06413618ea
2 changed files with 22 additions and 0 deletions

View File

@ -712,6 +712,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!operand) if (!operand)
return failure(); return failure();
auto read = operand->get().getDefiningOp<vector::TransferReadOp>(); auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
// Don't duplicate transfer_read ops when distributing.
if (!read.getResult().hasOneUse())
return failure();
unsigned operandIndex = operand->getOperandNumber(); unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex); Value distributedVal = warpOp.getResult(operandIndex);

View File

@ -650,3 +650,22 @@ func.func @lane_dependent_warp_propagate_read(
vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32> vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
return return
} }
// -----
// CHECK-PROP: func @dont_duplicate_read
func.func @dont_duplicate_read(
%laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
// CHECK-PROP-NEXT: vector.transfer_read
// CHECK-PROP-NEXT: "blocking_use"
// CHECK-PROP-NEXT: vector.yield
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
%2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
"blocking_use"(%2) : (vector<32xf32>) -> ()
vector.yield %2 : vector<32xf32>
}
return %r : vector<1xf32>
}