forked from OSchip/llvm-project
[mlir][vector] Don't duplicate transfer_read during vector distribution
Only apply the pattern if the transfer_read can be distributed for all its uses. Differential Revision: https://reviews.llvm.org/D133538
This commit is contained in:
parent
da8c9521ee
commit
06413618ea
|
@@ -712,6 +712,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||||
if (!operand)
|
if (!operand)
|
||||||
return failure();
|
return failure();
|
||||||
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
|
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
|
||||||
|
// Don't duplicate transfer_read ops when distributing.
|
||||||
|
if (!read.getResult().hasOneUse())
|
||||||
|
return failure();
|
||||||
unsigned operandIndex = operand->getOperandNumber();
|
unsigned operandIndex = operand->getOperandNumber();
|
||||||
Value distributedVal = warpOp.getResult(operandIndex);
|
Value distributedVal = warpOp.getResult(operandIndex);
|
||||||
|
|
||||||
|
|
|
@@ -650,3 +650,22 @@ func.func @lane_dependent_warp_propagate_read(
|
||||||
vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
|
vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-PROP: func @dont_duplicate_read
|
||||||
|
func.func @dont_duplicate_read(
|
||||||
|
%laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%cst = arith.constant 0.000000e+00 : f32
|
||||||
|
// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
|
||||||
|
// CHECK-PROP-NEXT: vector.transfer_read
|
||||||
|
// CHECK-PROP-NEXT: "blocking_use"
|
||||||
|
// CHECK-PROP-NEXT: vector.yield
|
||||||
|
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
|
||||||
|
%2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
|
||||||
|
"blocking_use"(%2) : (vector<32xf32>) -> ()
|
||||||
|
vector.yield %2 : vector<32xf32>
|
||||||
|
}
|
||||||
|
return %r : vector<1xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue