From 64f7fb5dfca14bead0e4b12142da2135f950034f Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Fri, 23 Apr 2021 18:11:07 +0900
Subject: [PATCH] [mlir] Support masked N-D vector transfer ops in ProgressiveVectorToSCF.

Mask vectors are handled similarly to data vectors in N-D TransferWriteOp.
They are copied into a temporary memory buffer, which can be indexed into
with non-constant values.

Differential Revision: https://reviews.llvm.org/D101136
---
 .../VectorToSCF/ProgressiveVectorToSCF.cpp    | 154 +++++++++++----
 .../Vector/CPU/test-transfer-read-2d.mlir     |  24 ++-
 2 files changed, 132 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index d42bd67082e8..08aca49c7af4 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -56,16 +56,34 @@ static MemRefType unpackOneDim(MemRefType type) {
                              vectorType.getElementType()));
 }
 
-// TODO: Parallelism and threadlocal considerations.
-static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
   auto &b = ScopedContext::getBuilderRef();
   OpBuilder::InsertionGuard guard(b);
   Operation *scope =
-      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
   assert(scope && "Expected op to be inside automatic allocation scope");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
-  Value res = memref_alloca(type);
-  return res;
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    result.maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), result.maskBuffer);
+  }
+
+  return result;
 }
 
 /// Given a vector transfer op, calculate which dimension of the `source`
@@ -238,6 +256,16 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
   return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
 }
 
+/// Given a transfer op, find the memref from which the mask is loaded. This
+/// is similar to Strategy::getBuffer.
+template <typename OpTy>
+static Value getMaskBuffer(OpTy xferOp) {
+  assert(xferOp.mask() && "Expected that transfer op has mask");
+  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
+  assert(loadOp && "Expected transfer op mask produced by LoadOp");
+  return loadOp.getMemRef();
+}
+
 /// Codegen strategy, depending on the operation.
 template <typename OpTy>
 struct Strategy;
@@ -266,9 +294,9 @@ struct Strategy<TransferReadOp> {
     return getStoreOp(xferOp).getMemRef();
   }
 
-  /// Retrieve the indices of the current StoreOp.
-  static void getStoreIndices(TransferReadOp xferOp,
-                              SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current StoreOp that stores into the buffer.
+  static void getBufferIndices(TransferReadOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -300,10 +328,11 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   ///       The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
-                        Value buffer, Value iv) {
+  /// Note: The `mask` operand is set in TransferOpConversion.
+  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                                  Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -321,6 +350,7 @@ struct Strategy<TransferReadOp> {
     newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
 
     memref_store(newXfer, buffer, storeIndices);
+    return newXfer.getDefiningOp<TransferReadOp>();
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
@@ -329,7 +359,7 @@ struct Strategy<TransferReadOp> {
       OpBuilder &/*builder*/, TransferReadOp xferOp, Value buffer,
       Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
@@ -361,9 +391,9 @@ struct Strategy<TransferWriteOp> {
     return loadOp.getMemRef();
   }
 
-  /// Retrieve the indices of the current LoadOp.
-  static void getLoadIndices(TransferWriteOp xferOp,
-                             SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current LoadOp that loads from the buffer.
+  static void getBufferIndices(TransferWriteOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -378,10 +408,10 @@ struct Strategy<TransferWriteOp> {
   /// to memory.
   ///
   /// Note: For more details, see comments on Strategy<TransferReadOp>.
-  static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
-                        Value buffer, Value iv) {
+  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
+                                   Value buffer, Value iv) {
     SmallVector<Value, 8> loadIndices;
-    getLoadIndices(xferOp, loadIndices);
+    getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -397,6 +427,8 @@ struct Strategy<TransferWriteOp> {
 
     if (vecType.getRank() > kTargetRank)
       newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+
+    return newXfer;
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
@@ -416,8 +448,6 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
     return failure();
   if (xferOp.getVectorType().getRank() <= kTargetRank)
     return failure();
-  if (xferOp.mask())
-    return failure();
   return success();
 }
 
@@ -442,6 +472,8 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferReadConversion
     : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
@@ -452,12 +484,16 @@ struct PrepareTransferReadConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    auto buffers = allocBuffers(xferOp);
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
-    memref_store(newXfer->getResult(0), buffer);
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffer);
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+    }
+
+    memref_store(newXfer->getResult(0), buffers.dataBuffer);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
 
     return success();
   }
@@ -484,6 +520,8 @@ struct PrepareTransferReadConversion
 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
 ///     : vector<5x4xf32>, memref<?x?x?xf32>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferWriteConversion
     : public OpRewritePattern<TransferWriteOp> {
   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
@@ -494,16 +532,20 @@ struct PrepareTransferWriteConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
-    memref_store(xferOp.vector(), buffer);
-    auto loadedVec = memref_load(buffer);
-
+    auto buffers = allocBuffers(xferOp);
+    memref_store(xferOp.vector(), buffers.dataBuffer);
+    auto loadedVec = memref_load(buffers.dataBuffer);
     rewriter.updateRootInPlace(xferOp, [&]() {
       xferOp.vectorMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      rewriter.updateRootInPlace(
+          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+    }
+
     return success();
   }
 };
@@ -535,16 +577,28 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    // How the buffer can be found depends on OpTy.
-    auto buffer = Strategy<OpTy>::getBuffer(xferOp);
-    auto bufferType = buffer.getType().template dyn_cast<MemRefType>();
-    auto castedType = unpackOneDim(bufferType);
-    auto casted = vector_type_cast(castedType, buffer);
+
+    // Find and cast data buffer. How the buffer can be found depends on OpTy.
+    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
+    auto castedDataType = unpackOneDim(dataBufferType);
+    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);
+
+    // If the xferOp has a mask: Find and cast mask buffer.
+    Value castedMaskBuffer;
+    if (xferOp.mask()) {
+      auto maskBuffer = getMaskBuffer(xferOp);
+      auto maskBufferType =
+          maskBuffer.getType().template dyn_cast<MemRefType>();
+      auto castedMaskType = unpackOneDim(maskBufferType);
+      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+    }
 
     // Loop bounds and step.
     auto lb = std_constant_index(0).value;
     auto ub = std_constant_index(
-        castedType.getDimSize(castedType.getRank() - 1)).value;
+                  castedDataType.getDimSize(castedDataType.getRank() - 1))
+                  .value;
     auto step = std_constant_index(1).value;
 
     // Generate for loop.
@@ -555,11 +609,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
           ScopedContext scope(b, loc);
           generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
-              /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-              }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
-              });
+              /*inBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                // Create new transfer op.
+                OpTy newXfer =
+                    Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+
+                // If old transfer op has a mask: Set mask on new transfer op.
+                if (xferOp.mask()) {
+                  OpBuilder::InsertionGuard guard(b);
+                  b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+                  SmallVector<Value, 8> loadIndices;
+                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+                  loadIndices.push_back(iv);
+
+                  auto mask = memref_load(castedMaskBuffer, loadIndices);
+                  rewriter.updateRootInPlace(
+                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
+                }
+              },
+              /*outOfBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, castedDataBuffer,
+                                                     iv);
+              });
 
           b.create<scf::YieldOp>(loc);
         });
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index cbe0aa52a437..f4eef8b98b76 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -1,8 +1,3 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
@@ -17,6 +12,19 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 1, 1, 1, 1, 0, 1],
+                          [1, 1, 1, 1, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -80,7 +88,10 @@ func @entry() {
   call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
   // Read shifted by 0 and pad with -42:
   call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // Same as above, but apply a mask
+  call @transfer_read_2d_mask(%A, %c0, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but without mask and transposed
   call @transfer_read_2d_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
   // Second vector dimension is a broadcast
@@ -92,5 +103,6 @@ func @entry() {
 // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
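
Illustration (not part of the patch): the MLIR below sketches, in simplified form, what the masked 2-D read from the test above is unrolled into after the prepare and unpack steps of this pass. It is hand-written for this description rather than verbatim pass output; the function name `transfer_read_2d_mask_sketch` is made up, the mask is passed in as an argument, plain `addi` stands in for the pass's index arithmetic, and the out-of-bounds handling on the unpacked dimension is omitted. The point it shows is the one the commit message makes: the mask is first stored into a temporary memref buffer so that per-row mask slices can be loaded with the non-constant loop index %i.

func @transfer_read_2d_mask_sketch(%A : memref<?x?xf32>, %base1: index,
                                    %base2: index, %mask : vector<4x9xi1>)
    -> vector<4x9xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %fm42 = constant -42.0 : f32

  // Temporary buffers for the result vector and the mask (cf. allocBuffers).
  %data_buf = memref.alloca() : memref<vector<4x9xf32>>
  %mask_buf = memref.alloca() : memref<vector<4x9xi1>>
  memref.store %mask, %mask_buf[] : memref<vector<4x9xi1>>

  // Unpack one vector dimension so both buffers can be indexed with %i.
  %data_rows = vector.type_cast %data_buf
      : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
  %mask_rows = vector.type_cast %mask_buf
      : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>

  scf.for %i = %c0 to %c4 step %c1 {
    // The mask row is loaded with a non-constant index, which is only
    // possible because the mask was copied into memory first.
    %row_mask = memref.load %mask_rows[%i] : memref<4xvector<9xi1>>
    %row_idx = addi %base1, %i : index
    // 1-D masked transfer that replaces one slice of the original 2-D read.
    %row = vector.transfer_read %A[%row_idx, %base2], %fm42, %row_mask
        : memref<?x?xf32>, vector<9xf32>
    memref.store %row, %data_rows[%i] : memref<4xvector<9xf32>>
  }

  %res = memref.load %data_buf[] : memref<vector<4x9xf32>>
  return %res : vector<4x9xf32>
}

For ranks above 2, the same unpacking repeats: each level's transfer keeps its mask operand as a memref.load from the casted mask buffer, which is exactly what getMaskBuffer looks up at the next level.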