[mlir][vector] Fix bug in vector-transfer-full-partial-split

When splitting with linalg.copy, the pattern cannot write into the destination alloc directly: the copied region may be smaller than the alloc, so the shapes need not match. Instead, write into a subview of the alloc with the same sizes as the copied region.
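Illustrated on the read path (value names, sizes, and layout maps follow the
split_vector_transfer_read_2d test and are placeholders, not verbatim pass
output):

    %sv = memref.subview %A[%i, %j] [%sv0, %sv1] [1, 1]
        : memref<?x8xf32> to memref<?x?xf32, #map_src>
    // Before: the dynamically sized subview was copied into the full
    // 4x8 staging alloc, so source and destination shapes could differ.
    linalg.copy(%sv, %alloc) : memref<?x?xf32, #map_src>, memref<4x8xf32>
    // After: copy into a subview of the alloc with matching sizes.
    %alloc_view = memref.subview %alloc[0, 0] [%sv0, %sv1] [1, 1]
        : memref<4x8xf32> to memref<?x?xf32, #map_dst>
    linalg.copy(%sv, %alloc_view)
        : memref<?x?xf32, #map_src>, memref<?x?xf32, #map_dst>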

Differential Revision: https://reviews.llvm.org/D110512
commit ffdf0a370d
parent 683e506324
Author: Matthias Springer
Date:   2021-09-27 17:13:11 +09:00
2 changed files with 31 additions and 22 deletions


@@ -1835,9 +1835,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
-static Value createSubViewIntersection(OpBuilder &b,
-                                       VectorTransferOpInterface xferOp,
-                                       Value alloc) {
+static std::pair<Value, Value> createSubViewIntersection(
+    OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
@@ -1864,11 +1863,15 @@ static Value createSubViewIntersection(OpBuilder &b,
     sizes.push_back(affineMin);
   });
 
-  SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
+  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
       xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
-  return lb.create<memref::SubViewOp>(
-      isaWrite ? alloc : xferOp.source(), indices, sizes,
-      SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
+  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+  auto copySrc = lb.create<memref::SubViewOp>(
+      isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
+  auto copyDest = lb.create<memref::SubViewOp>(
+      isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
+  return std::make_pair(copySrc, copyDest);
 }
 
 /// Given an `xferOp` for which:
@@ -1877,14 +1880,15 @@ static Value createSubViewIntersection(OpBuilder &b,
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
 ///      %2 = linalg.fill(%pad, %alloc)
 ///      %3 = subview %view [...][...][...]
-///      linalg.copy(%3, %alloc)
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
+///      %4 = subview %alloc [0, 0] [...] [...]
+///      linalg.copy(%3, %4)
+///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %5, ... : compatibleMemRefType, index, index
 ///    }
 /// ```
 /// Return the produced scf::IfOp.
@@ -1910,9 +1914,9 @@ createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
         b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
         // Take partial subview of memref which guarantees no dimension
        // overflows.
-        Value memRefSubView = createSubViewIntersection(
+        std::pair<Value, Value> copyArgs = createSubViewIntersection(
             b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-        b.create<linalg::CopyOp>(loc, memRefSubView, alloc);
+        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
             b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
@@ -2030,7 +2034,8 @@ getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
 ///    %notInBounds = xor %inBounds, %true
 ///    scf.if (%notInBounds) {
 ///      %3 = subview %alloc [...][...][...]
-///      linalg.copy(%3, %view)
+///      %4 = subview %view [0, 0][...][...]
+///      linalg.copy(%3, %4)
 ///    }
 /// ```
 static void createFullPartialLinalgCopy(OpBuilder &b,
@@ -2040,9 +2045,9 @@ static void createFullPartialLinalgCopy(OpBuilder &b,
   auto notInBounds =
       lb.create<XOrOp>(inBoundsCond, lb.create<ConstantIntOp>(true, 1));
   lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    Value memRefSubView = createSubViewIntersection(
+    std::pair<Value, Value> copyArgs = createSubViewIntersection(
         b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-    b.create<linalg::CopyOp>(loc, memRefSubView, xferOp.source());
+    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
     b.create<scf::YieldOp>(loc, ValueRange{});
   });
 }


@@ -81,7 +81,8 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
 // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
 // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_stride_8x1]]>
-// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<4x8xf32>
+// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<?x?xf32, #{{.*}}>
 // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
 // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
 // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -172,7 +173,8 @@ func @split_vector_transfer_read_strided_2d(
 // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
 // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_stride_1]]>
-// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<4x8xf32>
+// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<?x?xf32, #{{.*}}>
 // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
 // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
 // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -276,8 +278,9 @@ func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %
 // LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP4]]>
-// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME: : memref<?x?xf32, #[[MAP4]]>, memref<?x8xf32>
+// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME: : memref<?x?xf32, #[[MAP4]]>, memref<?x?xf32, #{{.*}}>
 // LINALG: }
 // LINALG: return
 // LINALG: }
@@ -384,8 +387,9 @@ func @split_vector_transfer_write_strided_2d(
 // LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP5]]>
-// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME: : memref<?x?xf32, #[[MAP5]]>, memref<7x8xf32, #[[MAP0]]>
+// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME: : memref<?x?xf32, #[[MAP5]]>, memref<?x?xf32, #[[MAP0]]>
 // LINALG: }
 // LINALG: return
 // LINALG: }