forked from OSchip/llvm-project
[mlir][vector] Add patterns to cast away leading 1-dim
This patch adds patterns to use vector.shape_cast to cast away leading 1-dimensions from a few vector operations. It allows exposing more canonical forms of vector.transfer_read, vector.transfer_write, vector_extract_strided_slice, and vector.insert_strided_slice. With this, we can have more opportunity to cancelling extract/insert ops or forwarding write/read ops. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D95873
This commit is contained in:
parent
2fbbb18c1d
commit
874ce9b80f
|
@ -35,6 +35,15 @@ void populateVectorToVectorCanonicalizationPatterns(
|
|||
void populateVectorToVectorTransformationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||
|
||||
/// Collect a set of leading one dimension removal patterns.
|
||||
///
|
||||
/// These patterns insert vector.shape_cast to remove leading one dimensions
|
||||
/// to expose more canonical forms of read/write/insert/extract operations.
|
||||
/// With them, there are more chances that we can cancel out extract-insert
|
||||
/// pairs or forward write-read pairs.
|
||||
void populateCastAwayVectorLeadingOneDimPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||
|
||||
/// Collect a set of vector slices transformation patterns:
|
||||
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
|
||||
/// Useful for clients that want to express all vector "slices"
|
||||
|
|
|
@ -2607,6 +2607,186 @@ struct TransferWriteInsertPattern
|
|||
}
|
||||
};
|
||||
|
||||
// Trims leading one dimensions from `oldType` and returns the result type.
|
||||
// Returns `vector<1xT>` if `oldType` only has one element.
|
||||
static VectorType trimLeadingOneDims(VectorType oldType) {
|
||||
ArrayRef<int64_t> oldShape = oldType.getShape();
|
||||
ArrayRef<int64_t> newShape =
|
||||
oldShape.drop_while([](int64_t dim) { return dim == 1; });
|
||||
// Make sure we have at least 1 dimension per vector type requirements.
|
||||
if (newShape.empty())
|
||||
newShape = oldShape.take_back();
|
||||
return VectorType::get(newShape, oldType.getElementType());
|
||||
}
|
||||
|
||||
// Casts away leading one dimensions in vector.extract_strided_slice's vector
|
||||
// input by inserting vector.shape_cast.
|
||||
struct CastAwayExtractStridedSliceLeadingOneDim
|
||||
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// vector.extract_strided_slice requires the input and output vector to have
|
||||
// the same rank. Here we drop leading one dimensions from the input vector
|
||||
// type to make sure we don't cause mismatch.
|
||||
VectorType oldSrcType = extractOp.getVectorType();
|
||||
VectorType newSrcType = trimLeadingOneDims(oldSrcType);
|
||||
|
||||
if (newSrcType.getRank() == oldSrcType.getRank())
|
||||
return failure();
|
||||
|
||||
int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
|
||||
|
||||
VectorType oldDstType = extractOp.getType();
|
||||
VectorType newDstType =
|
||||
VectorType::get(oldDstType.getShape().drop_front(dropCount),
|
||||
oldDstType.getElementType());
|
||||
|
||||
Location loc = extractOp.getLoc();
|
||||
|
||||
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, newSrcType, extractOp.vector());
|
||||
|
||||
// The offsets/sizes/strides attribute can have a less number of elements
|
||||
// than the input vector's rank: it is meant for the leading dimensions.
|
||||
auto newOffsets = rewriter.getArrayAttr(
|
||||
extractOp.offsets().getValue().drop_front(dropCount));
|
||||
auto newSizes = rewriter.getArrayAttr(
|
||||
extractOp.sizes().getValue().drop_front(dropCount));
|
||||
auto newStrides = rewriter.getArrayAttr(
|
||||
extractOp.strides().getValue().drop_front(dropCount));
|
||||
|
||||
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
|
||||
newExtractOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Casts away leading one dimensions in vector.extract_strided_slice's vector
|
||||
// inputs by inserting vector.shape_cast.
|
||||
struct CastAwayInsertStridedSliceLeadingOneDim
|
||||
: public OpRewritePattern<vector::InsertStridedSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
VectorType oldSrcType = insertOp.getSourceVectorType();
|
||||
VectorType newSrcType = trimLeadingOneDims(oldSrcType);
|
||||
VectorType oldDstType = insertOp.getDestVectorType();
|
||||
VectorType newDstType = trimLeadingOneDims(oldDstType);
|
||||
|
||||
if (newSrcType.getRank() == oldSrcType.getRank() &&
|
||||
newDstType.getRank() == oldDstType.getRank())
|
||||
return failure();
|
||||
|
||||
// Trim leading one dimensions from both operands.
|
||||
Location loc = insertOp.getLoc();
|
||||
|
||||
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, newSrcType, insertOp.source());
|
||||
Value newDstVector =
|
||||
rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
|
||||
|
||||
auto newOffsets = rewriter.getArrayAttr(
|
||||
insertOp.offsets().getValue().take_back(newDstType.getRank()));
|
||||
auto newStrides = rewriter.getArrayAttr(
|
||||
insertOp.strides().getValue().take_back(newSrcType.getRank()));
|
||||
|
||||
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
|
||||
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
|
||||
newInsertOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Turns vector.transfer_read on vector with leading 1 dimensions into
|
||||
// vector.shape_cast followed by vector.transfer_read on vector without leading
|
||||
// 1 dimensions.
|
||||
struct CastAwayTransferReadLeadingOneDim
|
||||
: public OpRewritePattern<vector::TransferReadOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransferReadOp read,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto shapedType = read.source().getType().cast<ShapedType>();
|
||||
if (shapedType.getElementType() != read.getVectorType().getElementType())
|
||||
return failure();
|
||||
|
||||
VectorType oldType = read.getVectorType();
|
||||
VectorType newType = trimLeadingOneDims(oldType);
|
||||
|
||||
if (newType == oldType)
|
||||
return failure();
|
||||
|
||||
AffineMap oldMap = read.permutation_map();
|
||||
ArrayRef<AffineExpr> newResults =
|
||||
oldMap.getResults().take_back(newType.getRank());
|
||||
AffineMap newMap =
|
||||
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
|
||||
rewriter.getContext());
|
||||
|
||||
ArrayAttr mask;
|
||||
if (read.masked())
|
||||
mask = rewriter.getArrayAttr(
|
||||
read.maskedAttr().getValue().take_back(newType.getRank()));
|
||||
|
||||
auto newRead = rewriter.create<vector::TransferReadOp>(
|
||||
read.getLoc(), newType, read.source(), read.indices(), newMap,
|
||||
read.padding(), mask);
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Turns vector.transfer_write on vector with leading 1 dimensions into
|
||||
// vector.shape_cast followed by vector.transfer_write on vector without leading
|
||||
// 1 dimensions.
|
||||
struct CastAwayTransferWriteLeadingOneDim
|
||||
: public OpRewritePattern<vector::TransferWriteOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto shapedType = write.source().getType().dyn_cast<ShapedType>();
|
||||
if (shapedType.getElementType() != write.getVectorType().getElementType())
|
||||
return failure();
|
||||
|
||||
VectorType oldType = write.getVectorType();
|
||||
VectorType newType = trimLeadingOneDims(oldType);
|
||||
|
||||
if (newType == oldType)
|
||||
return failure();
|
||||
|
||||
AffineMap oldMap = write.permutation_map();
|
||||
ArrayRef<AffineExpr> newResults =
|
||||
oldMap.getResults().take_back(newType.getRank());
|
||||
AffineMap newMap =
|
||||
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
|
||||
rewriter.getContext());
|
||||
|
||||
ArrayAttr mask;
|
||||
if (write.masked())
|
||||
mask = rewriter.getArrayAttr(
|
||||
write.maskedAttr().getValue().take_back(newType.getRank()));
|
||||
|
||||
auto newVector = rewriter.create<vector::ShapeCastOp>(
|
||||
write.getLoc(), newType, write.vector());
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||
write, newVector, write.source(), write.indices(), newMap, mask);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
||||
// TODO: Add this as DRR pattern.
|
||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||
|
@ -2622,6 +2802,15 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
|
||||
CastAwayInsertStridedSliceLeadingOneDim,
|
||||
CastAwayTransferReadLeadingOneDim,
|
||||
CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
|
||||
context);
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorSlicesLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
|
||||
|
|
|
@ -601,3 +601,73 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
|
|||
: vector<4x4xf32>, tensor<4x4xf32>
|
||||
return %r : tensor<4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
|
||||
func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
|
||||
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
|
||||
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
|
||||
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
|
||||
// CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
|
||||
// CHECK: return %[[RET]]
|
||||
return %0: vector<1x1x8xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
|
||||
func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
|
||||
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
|
||||
// CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
|
||||
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
|
||||
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
|
||||
// CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
|
||||
// CHECK: return %[[RET]]
|
||||
return %0: vector<1x8x8xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
|
||||
func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
|
||||
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
|
||||
// CHECK: vector.shape_cast %{{.+}} : vector<1x1x1xf16> to vector<1xf16>
|
||||
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
|
||||
return %0: vector<1x1x1xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
|
||||
func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
|
||||
%f0 = constant 0. : f16
|
||||
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
|
||||
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
|
||||
// CHECK: return %[[CAST]]
|
||||
return %0: vector<1x4xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
|
||||
func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
|
||||
%c0 = constant 0 : index
|
||||
%f0 = constant 0. : f16
|
||||
// CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x1x1x1xf16>, vector<1x1xf16>
|
||||
return %0: vector<1x1xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
|
||||
func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
|
||||
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
|
||||
|
||||
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
|
||||
func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
|
||||
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
|
|||
}
|
||||
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||
populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue