[mlir][vector] Add patterns to cast away leading 1-dim

This patch adds patterns to use vector.shape_cast to cast
away leading 1-dimensions from a few vector operations.
It allows exposing more canonical forms of vector.transfer_read,
vector.transfer_write, vector.extract_strided_slice, and
vector.insert_strided_slice. With this, we have more
opportunities to cancel extract/insert ops or forward
write/read ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95873
This commit is contained in:
Lei Zhang 2021-02-05 08:55:32 -05:00
parent 2fbbb18c1d
commit 874ce9b80f
4 changed files with 269 additions and 0 deletions

View File

@ -35,6 +35,15 @@ void populateVectorToVectorCanonicalizationPatterns(
void populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
///
/// The patterns are added to `patterns` and cover vector.transfer_read,
/// vector.transfer_write, vector.extract_strided_slice and
/// vector.insert_strided_slice; ShapeCastOpFolder is added alongside them.
void populateCastAwayVectorLeadingOneDimPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
/// Collect a set of vector slices transformation patterns:
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
/// Useful for clients that want to express all vector "slices"

View File

@ -2607,6 +2607,186 @@ struct TransferWriteInsertPattern
}
};
// Strips every leading unit dimension from the shape of `oldType` and returns
// a vector type with the remaining shape and the same element type.
// If all dimensions are 1, the innermost one is kept, giving `vector<1xT>`.
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> shape = oldType.getShape();
  // Never drop the last remaining dimension: vector types need at least one.
  while (shape.size() > 1 && shape.front() == 1)
    shape = shape.drop_front();
  return VectorType::get(shape, oldType.getElementType());
}
// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.shape_cast.
//
// The source is shape_cast to its type without leading one dimensions, the
// slice is extracted from that rank-reduced vector with the corresponding
// leading offsets/sizes/strides entries removed, and the result is shape_cast
// back to the original result type.
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // vector.extract_strided_slice requires the input and output vector to have
    // the same rank. Here we drop leading one dimensions from the input vector
    // type to make sure we don't cause mismatch.
    VectorType oldSrcType = extractOp.getVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);

    // Nothing to trim: leave the op alone.
    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    // The offsets/sizes/strides attributes can have fewer elements than the
    // input vector's rank: they are meant for the leading dimensions. If they
    // do not cover every dimension being dropped, drop_front(dropCount) below
    // would remove more elements than exist, so do not match in that case.
    ArrayRef<Attribute> oldOffsets = extractOp.offsets().getValue();
    ArrayRef<Attribute> oldSizes = extractOp.sizes().getValue();
    ArrayRef<Attribute> oldStrides = extractOp.strides().getValue();
    if (static_cast<int64_t>(oldOffsets.size()) < dropCount ||
        static_cast<int64_t>(oldSizes.size()) < dropCount ||
        static_cast<int64_t>(oldStrides.size()) < dropCount)
      return failure();

    // Drop the same number of leading dimensions from the result type so the
    // new source and result ranks still agree.
    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType());

    Location loc = extractOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
        loc, newSrcType, extractOp.vector());

    auto newOffsets = rewriter.getArrayAttr(oldOffsets.drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(oldSizes.drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(oldStrides.drop_front(dropCount));

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);

    // Cast back so users of the original op still see the original type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
                                                     newExtractOp);

    return success();
  }
};
// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.shape_cast.
//
// Both the source and the destination vector are shape_cast to their types
// without leading one dimensions, the insertion is done on the rank-reduced
// vectors, and the result is shape_cast back to the original destination type.
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    // Only match when at least one operand actually loses a dimension.
    if (newSrcType.getRank() == oldSrcType.getRank() &&
        newDstType.getRank() == oldDstType.getRank())
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
        loc, newSrcType, insertOp.source());
    Value newDstVector =
        rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());

    // Keep only the trailing entries: as many offsets as the new destination
    // rank and as many strides as the new source rank.
    auto newOffsets = rewriter.getArrayAttr(
        insertOp.offsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.strides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    // Cast back so users of the original op still see the original type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};
// Rewrites a vector.transfer_read whose result vector has leading 1 dimensions
// into a vector.transfer_read of the vector type without those dimensions,
// followed by a vector.shape_cast back to the original vector type.
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = read.getVectorType();

    // Only match when the source and the vector share the same element type.
    auto srcType = read.source().getType().cast<ShapedType>();
    if (srcType.getElementType() != vecType.getElementType())
      return failure();

    // No leading one dimension to remove.
    VectorType trimmedType = trimLeadingOneDims(vecType);
    if (trimmedType == vecType)
      return failure();
    int64_t keptRank = trimmedType.getRank();

    // Keep only the trailing permutation map results, one per remaining
    // vector dimension; the number of dims and symbols is unchanged.
    AffineMap permMap = read.permutation_map();
    ArrayRef<AffineExpr> keptResults = permMap.getResults().take_back(keptRank);
    AffineMap trimmedMap =
        AffineMap::get(permMap.getNumDims(), permMap.getNumSymbols(),
                       keptResults, rewriter.getContext());

    // Trim the `masked` attribute the same way when it is present; otherwise
    // pass a null attribute.
    ArrayAttr trimmedMask =
        read.masked() ? rewriter.getArrayAttr(
                            read.maskedAttr().getValue().take_back(keptRank))
                      : ArrayAttr();

    auto trimmedRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), trimmedType, read.source(), read.indices(), trimmedMap,
        read.padding(), trimmedMask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, vecType,
                                                     trimmedRead);

    return success();
  }
};
// Turns vector.transfer_write on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_write on vector without leading
// 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // Use cast<> rather than dyn_cast<>: the result is dereferenced
    // unconditionally right below, so a null handle from a failed dyn_cast<>
    // must never reach that point.
    auto shapedType = write.source().getType().cast<ShapedType>();
    // Only match when the source and the vector share the same element type.
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    // No leading one dimension to remove.
    if (newType == oldType)
      return failure();

    // Keep only the trailing permutation map results, one per remaining
    // vector dimension; the number of dims and symbols is unchanged.
    AffineMap oldMap = write.permutation_map();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    // Trim the `masked` attribute the same way when it is present; otherwise
    // `mask` stays a null attribute.
    ArrayAttr mask;
    if (write.masked())
      mask = rewriter.getArrayAttr(
          write.maskedAttr().getValue().take_back(newType.getRank()));

    auto newVector = rewriter.create<vector::ShapeCastOp>(
        write.getLoc(), newType, write.vector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.source(), write.indices(), newMap, mask);

    return success();
  }
};
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
@ -2622,6 +2802,15 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
// clang-format on
}
// Adds the patterns that cast away leading one dimensions to `patterns`.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // One pattern per supported op: strided slice extract/insert and transfer
  // read/write.
  patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
                  CastAwayInsertStridedSliceLeadingOneDim,
                  CastAwayTransferReadLeadingOneDim,
                  CastAwayTransferWriteLeadingOneDim>(context);
  // NOTE(review): ShapeCastOpFolder is defined elsewhere in this file;
  // presumably it folds the vector.shape_cast ops introduced by the patterns
  // above -- confirm.
  patterns.insert<ShapeCastOpFolder>(context);
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);

View File

@ -601,3 +601,73 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
: vector<4x4xf32>, tensor<4x4xf32>
return %r : tensor<4x4xf32>
}
// The leading unit dimension of the vector.extract_strided_slice source is
// cast away: the slice is taken from the rank-reduced vector with the leading
// offsets/sizes/strides entry dropped, and the result is cast back to the
// original result type.
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
// CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x1x8xf16>
}
// Both operands of vector.insert_strided_slice lose their leading unit
// dimension: offsets shrink from three entries to two and strides from two to
// one, and the result is cast back to the original destination type.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
// CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
// CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x8x8xf16>
}
// Single-element case: operands whose dimensions are all 1 are expected to be
// cast to vector<1xf16>, i.e. one dimension is always kept.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
// CHECK: vector.shape_cast %{{.+}} : vector<1x1x1xf16> to vector<1xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
return %0: vector<1x1x1xf16>
}
// A vector.transfer_read producing vector<1x4xf16> is expected to become a
// read of vector<4xf16> with the `masked` attribute trimmed to one entry,
// followed by a shape_cast back to vector<1x4xf16>.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
// CHECK: %[[C0:.+]] = constant 0 : index
%c0 = constant 0 : index
// CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
%f0 = constant 0. : f16
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
}
// Single-element case: for a read of vector<1x1xf16>, a shape_cast from
// vector<1xf16> back to vector<1x1xf16> is expected.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = constant 0 : index
%f0 = constant 0. : f16
// CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x1x1x1xf16>, vector<1x1xf16>
return %0: vector<1x1xf16>
}
// A vector.transfer_write of vector<1x4xf16> is expected to become a
// shape_cast to vector<4xf16> followed by a write of that vector with the
// `masked` attribute trimmed to one entry.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = constant 0 : index
%c0 = constant 0 : index
// CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
return
}
// Single-element case: for a write of vector<1x1xf16>, a shape_cast from
// vector<1x1xf16> to vector<1xf16> is expected.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = constant 0 : index
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
return
}

View File

@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
}
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}