From 3c3810e72e8b5d324be3c3de6faf177144653408 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Wed, 2 Feb 2022 05:21:02 -0500
Subject: [PATCH] [mlir][vector] Avoid hoisting alloca'ed temporary buffers
 across AutomaticAllocationScope

This revision avoids incorrect hoisting of alloca'd buffers across an
AutomaticAllocationScope boundary.
In the more general case, we will probably need a ParallelScope-like
interface.

Differential Revision: https://reviews.llvm.org/D118768
---
 .../Conversion/VectorToSCF/VectorToSCF.cpp         | 15 +++++++++++----
 .../VectorTransferSplitRewritePatterns.cpp         | 16 +++++++++++++---
 .../Conversion/VectorToSCF/vector-to-scf.mlir      | 19 +++++++++++++++++++
 .../vector-transfer-full-partial-split.mlir        | 19 +++++++++++++++++++
 4 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6cdad451e328..499f4403d317 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -267,15 +267,22 @@ struct BufferAllocs {
   Value maskBuffer;
 };
 
+// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
+static Operation *getAutomaticAllocationScope(Operation *op) {
+  Operation *scope =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  return scope;
+}
+
 /// Allocate temporary buffers for data (vector) and mask (if present).
-/// TODO: Parallelism and threadlocal considerations.
 template <typename OpTy>
 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
   Location loc = xferOp.getLoc();
   OpBuilder::InsertionGuard guard(b);
-  Operation *scope =
-      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
-  assert(scope && "Expected op to be inside automatic allocation scope");
+  Operation *scope = getAutomaticAllocationScope(xferOp);
+  assert(scope->getNumRegions() == 1 &&
+         "AutomaticAllocationScope with >1 regions");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
 
   BufferAllocs result;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ff3a6012f2d5..2cbc95d5d0f8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -438,6 +438,14 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
       });
 }
 
+// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
+static Operation *getAutomaticAllocationScope(Operation *op) {
+  Operation *scope =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  return scope;
+}
+
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
 ///
@@ -538,12 +546,14 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
 
   // Top of the function `alloc` for transient storage.
   Value alloc;
   {
-    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
     RewriterBase::InsertionGuard guard(b);
-    b.setInsertionPointToStart(&funcOp.getRegion().front());
+    Operation *scope = getAutomaticAllocationScope(xferOp);
+    assert(scope->getNumRegions() == 1 &&
+           "AutomaticAllocationScope with >1 regions");
+    b.setInsertionPointToStart(&scope->getRegion(0).front());
     auto shape = xferOp.getVectorType().getShape();
     Type elementType = xferOp.getVectorType().getElementType();
-    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
+    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                        MemRefType::get(shape, elementType),
                                        ValueRange{}, b.getI64IntegerAttr(32));
   }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 7cddb46f094e..15b70caa930f 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -481,3 +481,22 @@ func @transfer_write_strided(%A : vector<4xf32>, %B : memref<8x4xf32, affine_map
 // CHECK-LABEL: transfer_write_strided(
 // CHECK: scf.for
 // CHECK: store
+
+// -----
+
+func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
+
+// CHECK-LABEL: transfer_read_within_async_execute
+func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.token {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  // CHECK-NOT: alloca
+  // CHECK: async.execute
+  // CHECK: alloca
+  %token = async.execute {
+    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
+    call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
+    async.yield
+  }
+  return %token : !async.token
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index 9a10482e027a..ace977fb1e7a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -393,3 +393,22 @@ func @split_vector_transfer_write_strided_2d(
 // LINALG: }
 // LINALG: return
 // LINALG: }
+
+// -----
+
+func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
+
+// CHECK-LABEL: transfer_read_within_async_execute
+func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.token {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  // CHECK-NOT: alloca
+  // CHECK: async.execute
+  // CHECK: alloca
+  %token = async.execute {
+    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
+    call @fake_side_effecting_fun(%0) : (vector<2x2xf32>) -> ()
+    async.yield
+  }
+  return %token : !async.token
+}