[mlir][Vector] Patterns flattening vector transfers to 1D

This is the second part of https://reviews.llvm.org/D114993 after slicing
into 2 independent commits.

This is needed at the moment to get good codegen from 2d vector.transfer
ops that aim to compile to SIMD load/store instructions but that can
only do so if the whole 2d transfer shape is handled in one piece, in
particular taking advantage of the memref being contiguous rowmajor.

For instance, if the target architecture has 128bit SIMD then we would
expect that contiguous row-major transfers of <4x4xi8> map to one SIMD
load/store instruction each.

The current generic lowering of multi-dimensional vector.transfer ops
can't achieve that because it peels dimensions one by one, so a transfer
of <4x4xi8> becomes 4 transfers of <4xi8>.

The new patterns here are only enabled for now by
 -test-vector-transfer-flatten-patterns.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114993
This commit is contained in:
Benoit Jacob 2021-12-13 20:00:28 +00:00 committed by Nicolas Vasilache
parent d1327f8a57
commit aba437ceb2
6 changed files with 241 additions and 2 deletions

View File

@ -67,13 +67,21 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
/// Collect a set of leading one dimension removal patterns.
/// Collect a set of one dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove one
/// dimensions. With them, there are more chances that we can avoid
/// potentially exensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after

View File

@ -531,6 +531,11 @@ bool isStrided(MemRefType t);
/// Return null if the layout is not compatible with a strided layout.
AffineMap getStridedLinearLayoutMap(MemRefType t);
/// Helper determining if a memref is static-shape and contiguous-row-major
/// layout, while still allowing for an arbitrary offset (any static or
/// dynamic value).
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
} // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H

View File

@ -227,7 +227,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
MemRefType inputType = input.getType().cast<MemRefType>();
assert(inputType.hasStaticShape());
MemRefType resultType = dropUnitDims(inputType);
if (resultType == inputType)
if (canonicalizeStridedLayout(resultType) ==
canonicalizeStridedLayout(inputType))
return input;
SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
@ -333,6 +334,130 @@ class TransferWriteDropUnitDimsPattern
}
};
/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
Value rankReducedInput =
rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
ShapedType rankReducedInputType =
rankReducedInput.getType().cast<ShapedType>();
if (rankReducedInputType.getRank() == 1)
return rankReducedInput;
ReassociationIndices indices;
for (int i = 0; i < rankReducedInputType.getRank(); ++i)
indices.push_back(i);
return rewriter.create<memref::CollapseShapeOp>(
loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.vector();
VectorType vectorType = vector.getType().cast<VectorType>();
Value source = transferReadOp.source();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
// Already 1D, nothing to do.
return failure();
if (!isStaticShapeAndContiguousRowMajor(sourceType))
return failure();
if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
// This pattern requires the source to already be rank-reduced.
return failure();
if (sourceType.getNumElements() != vectorType.getNumElements())
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
if (!transferReadOp.permutation_map().isMinorIdentity())
return failure();
if (transferReadOp.mask())
return failure();
if (llvm::any_of(transferReadOp.indices(),
[](Value v) { return !isZero(v); }))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
sourceType.getElementType());
Value source1d =
collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
Value read1d = rewriter.create<vector::TransferReadOp>(
loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transferReadOp, vector.getType().cast<VectorType>(), read1d);
return success();
}
};
/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.vector();
VectorType vectorType = vector.getType().cast<VectorType>();
Value source = transferWriteOp.source();
MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
// Already 1D, nothing to do.
return failure();
if (!isStaticShapeAndContiguousRowMajor(sourceType))
return failure();
if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
// This pattern requires the source to already be rank-reduced.
return failure();
if (sourceType.getNumElements() != vectorType.getNumElements())
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
if (!transferWriteOp.permutation_map().isMinorIdentity())
return failure();
if (transferWriteOp.mask())
return failure();
if (llvm::any_of(transferWriteOp.indices(),
[](Value v) { return !isZero(v); }))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
sourceType.getElementType());
Value source1d =
collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
Value vector1d =
rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
ValueRange{c0}, identityMap1D);
rewriter.eraseOp(transferWriteOp);
return success();
}
};
} // namespace
void mlir::vector::transferOpflowOpt(FuncOp func) {
@ -358,3 +483,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
patterns.getContext());
populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
RewritePatternSet &patterns) {
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext());
populateShapeCastFoldingPatterns(patterns);
}

View File

@ -1168,3 +1168,40 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
return AffineMap();
return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}
/// Return the AffineExpr representation of the offset, assuming `memRefType`
/// is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
SmallVector<AffineExpr> strides;
AffineExpr offset;
if (failed(getStridesAndOffset(memrefType, strides, offset)))
assert(false && "expected strided memref");
return offset;
}
/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
/// `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType,
AffineExpr offset) {
AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
AffineExpr contiguousRowMajor = canonical + offset;
AffineMap contiguousRowMajorMap =
AffineMap::inferFromExprList({contiguousRowMajor})[0];
return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}
/// Helper determining if a memref is static-shape and contiguous-row-major
/// layout, while still allowing for an arbitrary offset (any static or
/// dynamic value).
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
if (!memrefType.hasStaticShape())
return false;
AffineExpr offset = getOffsetExpr(memrefType);
MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
memrefType.getContext(), memrefType.getShape(),
memrefType.getElementType(), offset);
return canonicalizeStridedLayout(memrefType) ==
canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}

View File

@ -0,0 +1,35 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
func @transfer_read_flattenable_with_offset(
%arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, vector<5x4x3x2xi8>
return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_flattenable_with_offset
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// C-HECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// C-HECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// C-HECK: return %[[VEC2D]]
// -----
func @transfer_write_flattenable_with_offset(
%arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<5x4x3x2xi8>, memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>
return
}
// C-HECK-LABEL: func @transfer_write_flattenable_with_offset
// C-HECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// C-HECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// C-HECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

View File

@ -598,6 +598,25 @@ struct TestVectorTransferDropUnitDimsPatterns
}
};
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-transfer-flatten-patterns";
}
StringRef getDescription() const final {
return "Test patterns to rewrite contiguous row-major N-dimensional "
"vector.transfer_{read,write} ops into 1D transfers";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateFlattenVectorTransferPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
} // namespace
namespace mlir {
@ -630,6 +649,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
PassRegistration<TestFlattenVectorTransferPatterns>();
}
} // namespace test
} // namespace mlir