forked from OSchip/llvm-project
[mlir][vector][bufferize] Fix transfer_write dropping mask operand
Differential Revision: https://reviews.llvm.org/D129253
This commit is contained in:
parent
6106a767b7
commit
a28ce1a42b
|
@ -106,7 +106,7 @@ struct TransferWriteOpInterface
|
|||
rewriter.create<vector::TransferWriteOp>(
|
||||
writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
|
||||
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
|
||||
writeOp.getInBoundsAttr());
|
||||
writeOp.getMask(), writeOp.getInBoundsAttr());
|
||||
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
|
||||
|
||||
return success();
|
||||
|
|
|
@ -15,16 +15,17 @@ func.func @transfer_read(%t: tensor<?x?xf32>, %o1: index,
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_write(
|
||||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>)
|
||||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>, %[[mask:.*]]: vector<5x6xi1>)
|
||||
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
|
||||
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref<?x?xf32>
|
||||
// CHECK: memref.copy %[[m]], %[[alloc]]
|
||||
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32>
|
||||
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]], %[[mask]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32>
|
||||
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?x?xf32>
|
||||
// CHECK: return %[[r]]
|
||||
func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
|
||||
%o2: index, %vec: vector<5x6xf32>) -> tensor<?x?xf32> {
|
||||
%0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]}
|
||||
%o2: index, %vec: vector<5x6xf32>,
|
||||
%mask: vector<5x6xi1>) -> tensor<?x?xf32> {
|
||||
%0 = vector.transfer_write %vec, %t[%o1, %o2], %mask {in_bounds = [true, false]}
|
||||
: vector<5x6xf32>, tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue