diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h index 0d18c5aa782d..835ad18a79ad 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -17,6 +17,11 @@ namespace mlir { class MLIRContext; class OwningRewritePatternList; +class VectorTransferOpInterface; + +namespace scf { +class IfOp; +} // namespace scf /// Collect a set of patterns to convert from the Vector dialect to itself. /// Should be merged with populateVectorToSCFLoweringPattern. @@ -104,6 +109,65 @@ private: FilterConstraintType filter; }; +/// Split a vector.transfer operation into an unmasked fastpath vector.transfer +/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result +/// is `success, the `ifOp` points to the newly created conditional upon +/// function return. To accomodate for the fact that the original +/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0] +/// in the temporary buffer, the scf.if op returns a view and values of type +/// index. At this time, only vector.transfer_read is implemented. +/// +/// Example (a 2-D vector.transfer_read): +/// ``` +/// %1 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// ``` +/// is transformed into: +/// ``` +/// %1:3 = scf.if (%inBounds) { +/// scf.yield %0 : memref, index, index +/// } else { +/// %2 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// %3 = vector.type_cast %extra_alloc : memref<...> to +/// memref> store %2, %3[] : memref> %4 = +/// memref_cast %extra_alloc: memref to memref scf.yield %4 : +/// memref, index, index +// } +/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]} +/// ``` +/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector. +/// +/// Preconditions: +/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` +/// must be equal. This will be relaxed in the future but requires +/// rank-reducing subviews. +LogicalResult +splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp); +LogicalResult splitFullAndPartialTransfer(OpBuilder &b, + VectorTransferOpInterface xferOp, + scf::IfOp *ifOp = nullptr); + +/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern +/// may take an extra filter to perform selection at a finer granularity. +struct VectorTransferFullPartialRewriter : public RewritePattern { + using FilterConstraintType = + std::function; + + explicit VectorTransferFullPartialRewriter( + MLIRContext *context, + FilterConstraintType filter = + [](VectorTransferOpInterface op) { return success(); }, + PatternBenefit benefit = 1) + : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {} + + /// Performs the rewrite. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + FilterConstraintType filter; +}; + } // namespace vector //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index aefbb7d47117..218715318a86 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -160,6 +160,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*defaultImplementation=*/ "return $_op.getMemRefType().getRank() - $_op.getTransferRank();" >, + InterfaceMethod< + /*desc=*/[{ Returns true if at least one of the dimensions is masked.}], + /*retTy=*/"bool", + /*methodName=*/"hasMaskedDim", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + for (unsigned idx = 0, e = $_op.getTransferRank(); idx < e; ++idx) + if ($_op.isMaskedDim(idx)) + return true; + return false; + }] + >, InterfaceMethod< /*desc=*/[{ Helper function to account for the fact that `permutationMap` results and diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 197b1c62274b..573b822503f3 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -12,9 +12,13 @@ #include +#include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/EDSC/Intrinsics.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Dialect/Vector/VectorUtils.h" @@ -1985,6 +1989,236 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, } // namespace mlir +static Optional extractConstantIndex(Value v) { + if (auto cstOp = v.getDefiningOp()) + return cstOp.getValue(); + if (auto affineApplyOp = v.getDefiningOp()) + if (affineApplyOp.getAffineMap().isSingleConstant()) + return affineApplyOp.getAffineMap().getSingleConstantResult(); + return None; +} + +// Missing foldings of scf.if make it necessary to perform poor man's folding +// eagerly, especially in the case of unrolling. In the future, this should go +// away once scf.if folds properly. +static Value createScopedFoldedSLE(Value v, Value ub) { + using namespace edsc::op; + auto maybeCstV = extractConstantIndex(v); + auto maybeCstUb = extractConstantIndex(ub); + if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb) + return Value(); + return sle(v, ub); +} + +// Operates under a scoped context to build the condition to ensure that a +// particular VectorTransferOpInterface is unmasked. +static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) { + assert(xferOp.permutation_map().isMinorIdentity() && + "Expected minor identity map"); + Value inBoundsCond; + xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { + // Zip over the resulting vector shape and memref indices. + // If the dimension is known to be unmasked, it does not participate in the + // construction of `inBoundsCond`. + if (!xferOp.isMaskedDim(resultIdx)) + return; + int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx); + using namespace edsc::op; + using namespace edsc::intrinsics; + // Fold or create the check that `index + vector_size` <= `memref_size`. + Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize); + Value cond = + createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx)); + if (!cond) + return; + // Conjunction over all dims for which we are in-bounds. + inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond; + }); + return inBoundsCond; +} + +LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition( + VectorTransferOpInterface xferOp) { + // TODO: expand support to these 2 cases. + if (!xferOp.permutation_map().isMinorIdentity()) + return failure(); + // TODO: relax this precondition. This will require rank-reducing subviews. + if (xferOp.getMemRefType().getRank() != xferOp.getTransferRank()) + return failure(); + // Must have some masked dimension to be a candidate for splitting. + if (!xferOp.hasMaskedDim()) + return failure(); + // Don't split transfer operations under IfOp, this avoids applying the + // pattern recursively. + // TODO: improve the condition to make it more applicable. + if (xferOp.getParentOfType()) + return failure(); + return success(); +} + +MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) { + if (MemRefCastOp::areCastCompatible(aT, bT)) + return aT; + if (aT.getRank() != bT.getRank()) + return MemRefType(); + int64_t aOffset, bOffset; + SmallVector aStrides, bStrides; + if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || + failed(getStridesAndOffset(bT, bStrides, bOffset)) || + aStrides.size() != bStrides.size()) + return MemRefType(); + + ArrayRef aShape = aT.getShape(), bShape = bT.getShape(); + int64_t resOffset; + SmallVector resShape(aT.getRank(), 0), + resStrides(bT.getRank(), 0); + for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) { + resShape[idx] = + (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize; + resStrides[idx] = (aStrides[idx] == bStrides[idx]) + ? aStrides[idx] + : MemRefType::kDynamicStrideOrOffset; + } + resOffset = + (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset; + return MemRefType::get( + resShape, aT.getElementType(), + makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext())); +} + +/// Split a vector.transfer operation into an unmasked fastpath vector.transfer +/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result +/// is `success, the `ifOp` points to the newly created conditional upon +/// function return. To accomodate for the fact that the original +/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0] +/// in the temporary buffer, the scf.if op returns a view and values of type +/// index. At this time, only vector.transfer_read is implemented. +/// +/// Example (a 2-D vector.transfer_read): +/// ``` +/// %1 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// ``` +/// is transformed into: +/// ``` +/// %1:3 = scf.if (%inBounds) { +/// scf.yield %0 : memref, index, index +/// } else { +/// %2 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// %3 = vector.type_cast %extra_alloc : memref<...> to +/// memref> store %2, %3[] : memref> %4 = +/// memref_cast %extra_alloc: memref to memref scf.yield %4 : +/// memref, index, index +// } +/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]} +/// ``` +/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector. +/// +/// Preconditions: +/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` +/// must be equal. This will be relaxed in the future but requires +/// rank-reducing subviews. +LogicalResult mlir::vector::splitFullAndPartialTransfer( + OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) { + using namespace edsc; + using namespace edsc::intrinsics; + + assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) && + "Expected splitFullAndPartialTransferPrecondition to hold"); + auto xferReadOp = dyn_cast(xferOp.getOperation()); + + // TODO: add support for write case. + if (!xferReadOp) + return failure(); + + OpBuilder::InsertionGuard guard(b); + if (xferOp.memref().getDefiningOp()) + b.setInsertionPointAfter(xferOp.memref().getDefiningOp()); + else + b.setInsertionPoint(xferOp); + ScopedContext scope(b, xferOp.getLoc()); + Value inBoundsCond = createScopedInBoundsCond( + cast(xferOp.getOperation())); + if (!inBoundsCond) + return failure(); + + // Top of the function `alloc` for transient storage. + Value alloc; + { + FuncOp funcOp = xferOp.getParentOfType(); + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(&funcOp.getRegion().front()); + auto shape = xferOp.getVectorType().getShape(); + Type elementType = xferOp.getVectorType().getElementType(); + alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{}, + b.getI64IntegerAttr(32)); + } + + Value memref = xferOp.memref(); + SmallVector bools(xferOp.getTransferRank(), false); + auto unmaskedAttr = b.getBoolArrayAttr(bools); + + MemRefType compatibleMemRefType = getCastCompatibleMemRefType( + xferOp.getMemRefType(), alloc.getType().cast()); + + // Read case: full fill + partial copy -> unmasked vector.xfer_read. + Value zero = std_constant_index(0); + SmallVector returnTypes(1 + xferOp.getTransferRank(), + b.getIndexType()); + returnTypes[0] = compatibleMemRefType; + scf::IfOp fullPartialIfOp; + conditionBuilder( + returnTypes, inBoundsCond, + [&]() -> scf::ValueVector { + Value res = memref; + if (compatibleMemRefType != xferOp.getMemRefType()) + res = std_memref_cast(memref, compatibleMemRefType); + scf::ValueVector viewAndIndices{res}; + viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), + xferOp.indices().end()); + return viewAndIndices; + }, + [&]() -> scf::ValueVector { + Operation *newXfer = + ScopedContext::getBuilderRef().clone(*xferOp.getOperation()); + Value vector = cast(newXfer).vector(); + std_store(vector, vector_type_cast( + MemRefType::get({}, vector.getType()), alloc)); + + Value casted = std_memref_cast(alloc, compatibleMemRefType); + scf::ValueVector viewAndIndices{casted}; + viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), + zero); + + return viewAndIndices; + }, + &fullPartialIfOp); + if (ifOp) + *ifOp = fullPartialIfOp; + + // Unmask the existing read op, it always reads from a full buffer. + for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) + xferReadOp.setOperand(i, fullPartialIfOp.getResult(i)); + xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr); + + return success(); +} + +LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + auto xferOp = dyn_cast(op); + if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) || + failed(filter(xferOp))) + return failure(); + rewriter.startRootUpdate(xferOp); + if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) { + rewriter.finalizeRootUpdate(xferOp); + return success(); + } + rewriter.cancelRootUpdate(xferOp); + return failure(); +} + // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO: Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir new file mode 100644 index 000000000000..ef76247ee9d4 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s + +// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)> +// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + +// CHECK-LABEL: split_vector_transfer_read_2d( +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref +// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index +// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index +func @split_vector_transfer_read_2d(%A: memref, %i: index, %j: index) -> vector<4x8xf32> { + %c0 = constant 0 : index + %f0 = constant 0.0 : f32 + + // CHECK-DAG: %[[c0:.*]] = constant 0 : index + // CHECK-DAG: %[[c8:.*]] = constant 8 : index + // CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32 + // alloca for boundary full tile + // CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32> + // %i + 4 <= dim(%A, 0) + // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] + // CHECK: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref + // CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index + // %j + 8 <= dim(%A, 1) + // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] + // CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index + // are both conds true + // CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1 + // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { + // inBounds, just yield %A + // CHECK: scf.yield %[[A]], %[[i]], %[[j]] : memref, index, index + // CHECK: } else { + // slow path, fill tmp alloc and yield a memref_casted version of it + // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst : + // CHECK-SAME: memref, vector<4x8xf32> + // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref> + // CHECK: store %[[slow]], %[[cast_alloc]][] : memref> + // CHECK: %[[yielded:.*]] = memref_cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref + // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] : + // CHECK-SAME: memref, index, index + // CHECK: } + // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]] + // CHECK_SAME: {masked = [false, false]} : memref, vector<4x8xf32> + %1 = vector.transfer_read %A[%i, %j], %f0 : memref, vector<4x8xf32> + + // CHECK: return %[[res]] : vector<4x8xf32> + return %1: vector<4x8xf32> +} + +// CHECK-LABEL: split_vector_transfer_read_strided_2d( +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref +// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index +// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index +func @split_vector_transfer_read_strided_2d( + %A: memref<7x8xf32, offset:?, strides:[?, 1]>, + %i: index, %j: index) -> vector<4x8xf32> { + %c0 = constant 0 : index + %f0 = constant 0.0 : f32 + + // CHECK-DAG: %[[c0:.*]] = constant 0 : index + // CHECK-DAG: %[[c7:.*]] = constant 7 : index + // CHECK-DAG: %[[c8:.*]] = constant 8 : index + // CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32 + // alloca for boundary full tile + // CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32> + // %i + 4 <= dim(%A, 0) + // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] + // CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index + // %j + 8 <= dim(%A, 1) + // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] + // CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index + // are both conds true + // CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1 + // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { + // inBounds but not cast-compatible: yield a memref_casted form of %A + // CHECK: %[[casted:.*]] = memref_cast %arg0 : + // CHECK-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref + // CHECK: scf.yield %[[casted]], %[[i]], %[[j]] : + // CHECK-SAME: memref, index, index + // CHECK: } else { + // slow path, fill tmp alloc and yield a memref_casted version of it + // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst : + // CHECK-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32> + // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref> + // CHECK: store %[[slow]], %[[cast_alloc]][] : + // CHECK-SAME: memref> + // CHECK: %[[yielded:.*]] = memref_cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref + // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] : + // CHECK-SAME: memref, index, index + // CHECK: } + // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} : + // CHECK-SAME: memref, vector<4x8xf32> + %1 = vector.transfer_read %A[%i, %j], %f0 : + memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32> + + // CHECK: return %[[res]] : vector<4x8xf32> + return %1 : vector<4x8xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 2058706dcbdd..0bba74e76385 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -122,6 +122,17 @@ struct TestVectorUnrollingPatterns } }; +struct TestVectorTransferFullPartialSplitPatterns + : public PassWrapper { + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + patterns.insert(ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + } // end anonymous namespace namespace mlir { @@ -141,5 +152,10 @@ void registerTestVectorConversions() { PassRegistration contractionUnrollingPass( "test-vector-unrolling-patterns", "Test conversion patterns to unroll contract ops in the vector dialect"); + + PassRegistration + vectorTransformFullPartialPass("test-vector-transfer-full-partial-split", + "Test conversion patterns to split " + "transfer ops via scf.if + linalg ops"); } } // namespace mlir