From aba437ceb2379f219935b98a10ca3c5081f0c8b7 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Mon, 13 Dec 2021 20:00:28 +0000 Subject: [PATCH] [mlir][Vector] Patterns flattening vector transfers to 1D This is the second part of https://reviews.llvm.org/D114993 after slicing into 2 independent commits. This is needed at the moment to get good codegen from 2d vector.transfer ops that aim to compile to SIMD load/store instructions but that can only do so if the whole 2d transfer shape is handled in one piece, in particular taking advantage of the memref being contiguous rowmajor. For instance, if the target architecture has 128bit SIMD then we would expect that contiguous row-major transfers of <4x4xi8> map to one SIMD load/store instruction each. The current generic lowering of multi-dimensional vector.transfer ops can't achieve that because it peels dimensions one by one, so a transfer of <4x4xi8> becomes 4 transfers of <4xi8>. The new patterns here are only enabled for now by -test-vector-transfer-flatten-patterns. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114993 --- mlir/include/mlir/Dialect/Vector/VectorOps.h | 10 +- mlir/include/mlir/IR/BuiltinTypes.h | 5 + .../Vector/VectorTransferOpTransforms.cpp | 135 +++++++++++++++++- mlir/lib/IR/BuiltinTypes.cpp | 37 +++++ .../Vector/vector-transfer-flatten.mlir | 35 +++++ .../Dialect/Vector/TestVectorTransforms.cpp | 21 +++ 6 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-transfer-flatten.mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index c6b63a949f64..14bd03968fcf 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -67,13 +67,21 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns); /// pairs or forward write-read pairs. void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns); -/// Collect a set of leading one dimension removal patterns. +/// Collect a set of one dimension removal patterns. /// /// These patterns insert rank-reducing memref.subview ops to remove one /// dimensions. With them, there are more chances that we can avoid /// potentially exensive vector.shape_cast operations. void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns); +/// Collect a set of patterns to flatten n-D vector transfers on contiguous +/// memref. +/// +/// These patterns insert memref.collapse_shape + vector.shape_cast patterns +/// to transform multiple small n-D transfers into a larger 1-D transfer where +/// the memref contiguity properties allow it. +void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns); + /// Collect a set of patterns that bubble up/down bitcast ops. /// /// These patterns move vector.bitcast ops to be before insert ops or after diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 296b12c42792..be98b577ffab 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -531,6 +531,11 @@ bool isStrided(MemRefType t); /// Return null if the layout is not compatible with a strided layout. AffineMap getStridedLinearLayoutMap(MemRefType t); +/// Helper determining if a memref is static-shape and contiguous-row-major +/// layout, while still allowing for an arbitrary offset (any static or +/// dynamic value). +bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType); + } // namespace mlir #endif // MLIR_IR_BUILTINTYPES_H diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp index c9438c4a28f4..9b1ae7a40226 100644 --- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp @@ -227,7 +227,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, MemRefType inputType = input.getType().cast(); assert(inputType.hasStaticShape()); MemRefType resultType = dropUnitDims(inputType); - if (resultType == inputType) + if (canonicalizeStridedLayout(resultType) == + canonicalizeStridedLayout(inputType)) return input; SmallVector subviewOffsets(inputType.getRank(), 0); SmallVector subviewStrides(inputType.getRank(), 1); @@ -333,6 +334,130 @@ class TransferWriteDropUnitDimsPattern } }; +/// Creates a memref.collapse_shape collapsing all of the dimensions of the +/// input into a 1D shape. +// TODO: move helper function +static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter, + mlir::Location loc, + Value input) { + Value rankReducedInput = + rankReducingSubviewDroppingUnitDims(rewriter, loc, input); + ShapedType rankReducedInputType = + rankReducedInput.getType().cast(); + if (rankReducedInputType.getRank() == 1) + return rankReducedInput; + ReassociationIndices indices; + for (int i = 0; i < rankReducedInputType.getRank(); ++i) + indices.push_back(i); + return rewriter.create( + loc, rankReducedInput, std::array{indices}); +} + +/// Rewrites contiguous row-major vector.transfer_read ops by inserting +/// memref.collapse_shape on the source so that the resulting +/// vector.transfer_read has a 1D source. Requires the source shape to be +/// already reduced i.e. without unit dims. +class FlattenContiguousRowMajorTransferReadPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, + PatternRewriter &rewriter) const override { + auto loc = transferReadOp.getLoc(); + Value vector = transferReadOp.vector(); + VectorType vectorType = vector.getType().cast(); + Value source = transferReadOp.source(); + MemRefType sourceType = source.getType().dyn_cast(); + // Contiguity check is valid on tensors only. + if (!sourceType) + return failure(); + if (vectorType.getRank() == 1 && sourceType.getRank() == 1) + // Already 1D, nothing to do. + return failure(); + if (!isStaticShapeAndContiguousRowMajor(sourceType)) + return failure(); + if (getReducedRank(sourceType.getShape()) != sourceType.getRank()) + // This pattern requires the source to already be rank-reduced. + return failure(); + if (sourceType.getNumElements() != vectorType.getNumElements()) + return failure(); + // TODO: generalize this pattern, relax the requirements here. + if (transferReadOp.hasOutOfBoundsDim()) + return failure(); + if (!transferReadOp.permutation_map().isMinorIdentity()) + return failure(); + if (transferReadOp.mask()) + return failure(); + if (llvm::any_of(transferReadOp.indices(), + [](Value v) { return !isZero(v); })) + return failure(); + Value c0 = rewriter.create(loc, 0); + auto identityMap1D = rewriter.getMultiDimIdentityMap(1); + VectorType vectorType1d = VectorType::get({sourceType.getNumElements()}, + sourceType.getElementType()); + Value source1d = + collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source); + Value read1d = rewriter.create( + loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D); + rewriter.replaceOpWithNewOp( + transferReadOp, vector.getType().cast(), read1d); + return success(); + } +}; + +/// Rewrites contiguous row-major vector.transfer_write ops by inserting +/// memref.collapse_shape on the source so that the resulting +/// vector.transfer_write has a 1D source. Requires the source shape to be +/// already reduced i.e. without unit dims. +class FlattenContiguousRowMajorTransferWritePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, + PatternRewriter &rewriter) const override { + auto loc = transferWriteOp.getLoc(); + Value vector = transferWriteOp.vector(); + VectorType vectorType = vector.getType().cast(); + Value source = transferWriteOp.source(); + MemRefType sourceType = source.getType().dyn_cast(); + // Contiguity check is valid on tensors only. + if (!sourceType) + return failure(); + if (vectorType.getRank() == 1 && sourceType.getRank() == 1) + // Already 1D, nothing to do. + return failure(); + if (!isStaticShapeAndContiguousRowMajor(sourceType)) + return failure(); + if (getReducedRank(sourceType.getShape()) != sourceType.getRank()) + // This pattern requires the source to already be rank-reduced. + return failure(); + if (sourceType.getNumElements() != vectorType.getNumElements()) + return failure(); + // TODO: generalize this pattern, relax the requirements here. + if (transferWriteOp.hasOutOfBoundsDim()) + return failure(); + if (!transferWriteOp.permutation_map().isMinorIdentity()) + return failure(); + if (transferWriteOp.mask()) + return failure(); + if (llvm::any_of(transferWriteOp.indices(), + [](Value v) { return !isZero(v); })) + return failure(); + Value c0 = rewriter.create(loc, 0); + auto identityMap1D = rewriter.getMultiDimIdentityMap(1); + VectorType vectorType1d = VectorType::get({sourceType.getNumElements()}, + sourceType.getElementType()); + Value source1d = + collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source); + Value vector1d = + rewriter.create(loc, vectorType1d, vector); + rewriter.create(loc, vector1d, source1d, + ValueRange{c0}, identityMap1D); + rewriter.eraseOp(transferWriteOp); + return success(); + } +}; + } // namespace void mlir::vector::transferOpflowOpt(FuncOp func) { @@ -358,3 +483,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns( patterns.getContext()); populateShapeCastFoldingPatterns(patterns); } + +void mlir::vector::populateFlattenVectorTransferPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); + populateShapeCastFoldingPatterns(patterns); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 10c38a86314f..8e408e440dc3 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -1168,3 +1168,40 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { return AffineMap(); return makeStridedLinearLayoutMap(strides, offset, t.getContext()); } + +/// Return the AffineExpr representation of the offset, assuming `memRefType` +/// is a strided memref. +static AffineExpr getOffsetExpr(MemRefType memrefType) { + SmallVector strides; + AffineExpr offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) + assert(false && "expected strided memref"); + return offset; +} + +/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and +/// `offset` AffineExpr. +static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, + ArrayRef shape, + Type elementType, + AffineExpr offset) { + AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); + AffineExpr contiguousRowMajor = canonical + offset; + AffineMap contiguousRowMajorMap = + AffineMap::inferFromExprList({contiguousRowMajor})[0]; + return MemRefType::get(shape, elementType, contiguousRowMajorMap); +} + +/// Helper determining if a memref is static-shape and contiguous-row-major +/// layout, while still allowing for an arbitrary offset (any static or +/// dynamic value). +bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { + if (!memrefType.hasStaticShape()) + return false; + AffineExpr offset = getOffsetExpr(memrefType); + MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( + memrefType.getContext(), memrefType.getShape(), + memrefType.getElementType(), offset); + return canonicalizeStridedLayout(memrefType) == + canonicalizeStridedLayout(contiguousRowMajorMemRefType); +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir new file mode 100644 index 000000000000..68a6779461d6 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s + +func @transfer_read_flattenable_with_offset( + %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>) -> vector<5x4x3x2xi8> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> +} + +// CHECK-LABEL: func @transfer_read_flattenable_with_offset +// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3] +// C-HECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]] +// C-HECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8> +// C-HECK: return %[[VEC2D]] + +// ----- + +func @transfer_write_flattenable_with_offset( + %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, %vec : vector<5x4x3x2xi8>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]> + return +} + +// C-HECK-LABEL: func @transfer_write_flattenable_with_offset +// C-HECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 +// C-HECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8> +// C-HECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}> +// C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8> +// C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] + diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index cf33b0d7117d..a0d5a1b915ff 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -598,6 +598,25 @@ struct TestVectorTransferDropUnitDimsPatterns } }; +struct TestFlattenVectorTransferPatterns + : public PassWrapper { + StringRef getArgument() const final { + return "test-vector-transfer-flatten-patterns"; + } + StringRef getDescription() const final { + return "Test patterns to rewrite contiguous row-major N-dimensional " + "vector.transfer_{read,write} ops into 1D transfers"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { + RewritePatternSet patterns(&getContext()); + populateFlattenVectorTransferPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + } +}; + } // namespace namespace mlir { @@ -630,6 +649,8 @@ void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir