[mlir][vector] Avoid hoisting alloca'ed temporary buffers across AutomaticAllocationScope

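This revision fixes incorrect hoisting of alloca'ed temporary buffers across an AutomaticAllocationScope boundary: the buffers are now allocated at the start of the enclosing AutomaticAllocationScope region.
In the more general case (parallelism, thread-local storage), we will probably need a ParallelScope-like interface.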

Differential Revision: https://reviews.llvm.org/D118768
This commit is contained in:
Nicolas Vasilache 2022-02-02 05:21:02 -05:00
parent 83b74544c6
commit 3c3810e72e
4 changed files with 62 additions and 7 deletions

View File

@ -267,15 +267,22 @@ struct BufferAllocs {
Value maskBuffer;
};
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
/// Looks up, via `getParentWithTrait`, the parent of `op` that carries the
/// `AutomaticAllocationScope` trait. Asserts that the lookup succeeded and
/// returns the operation that was found.
static Operation *getAutomaticAllocationScope(Operation *op) {
  auto *allocationScope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(allocationScope &&
         "Expected op to be inside automatic allocation scope");
  return allocationScope;
}
/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
Location loc = xferOp.getLoc();
OpBuilder::InsertionGuard guard(b);
Operation *scope =
xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
Operation *scope = getAutomaticAllocationScope(xferOp);
assert(scope->getNumRegions() == 1 &&
"AutomaticAllocationScope with >1 regions");
b.setInsertionPointToStart(&scope->getRegion(0).front());
BufferAllocs result;

View File

@ -438,6 +438,14 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
});
}
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
/// Returns the parent operation of `op` that has the
/// `AutomaticAllocationScope` trait, as reported by `getParentWithTrait`.
/// A missing parent is treated as a programming error and trips the assert.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *enclosingScope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(enclosingScope &&
         "Expected op to be inside automatic allocation scope");
  return enclosingScope;
}
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
@ -538,12 +546,14 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
// `alloc` for transient storage, placed at the top of the enclosing AutomaticAllocationScope.
Value alloc;
{
FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
RewriterBase::InsertionGuard guard(b);
b.setInsertionPointToStart(&funcOp.getRegion().front());
Operation *scope = getAutomaticAllocationScope(xferOp);
assert(scope->getNumRegions() == 1 &&
"AutomaticAllocationScope with >1 regions");
b.setInsertionPointToStart(&scope->getRegion(0).front());
auto shape = xferOp.getVectorType().getShape();
Type elementType = xferOp.getVectorType().getElementType();
alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
alloc = b.create<memref::AllocaOp>(scope->getLoc(),
MemRefType::get(shape, elementType),
ValueRange{}, b.getI64IntegerAttr(32));
}

View File

@ -481,3 +481,22 @@ func @transfer_write_strided(%A : vector<4xf32>, %B : memref<8x4xf32, affine_map
// CHECK-LABEL: transfer_write_strided(
// CHECK: scf.for
// CHECK: store
// -----
// Bodiless private declaration; it is called below with the result of the
// transfer_read inside the async.execute region.
func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
// Test case: a vector.transfer_read nested inside an async.execute region.
// The directives below expect no `alloca` in the output before the
// `async.execute` op and expect one after it.
// CHECK-LABEL: transfer_read_within_async_execute
func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.token {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
// The order of the next three directives matters: the negative match only
// covers the output between the label and the `async.execute` match.
// CHECK-NOT: alloca
// CHECK: async.execute
// CHECK: alloca
%token = async.execute {
%0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
async.yield
}
return %token : !async.token
}

View File

@ -393,3 +393,22 @@ func @split_vector_transfer_write_strided_2d(
// LINALG: }
// LINALG: return
// LINALG: }
// -----
// Bodiless private declaration; it is called below with the result of the
// transfer_read inside the async.execute region.
func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
// Test case: a vector.transfer_read from a dynamically shaped memref, nested
// inside an async.execute region. The directives below expect no `alloca` in
// the output before the `async.execute` op and expect one after it.
// CHECK-LABEL: transfer_read_within_async_execute
func @transfer_read_within_async_execute(%A : memref<?x?xf32>) -> !async.token {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
// The order of the next three directives matters: the negative match only
// covers the output between the label and the `async.execute` match.
// CHECK-NOT: alloca
// CHECK: async.execute
// CHECK: alloca
%token = async.execute {
%0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x?xf32>, vector<2x2xf32>
call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
async.yield
}
return %token : !async.token
}