forked from OSchip/llvm-project
[mlir][vector] Add patterns to cast away leading 1-dim
This patch adds patterns that use vector.shape_cast to cast away leading 1-dimensions from a few vector operations. This exposes more canonical forms of vector.transfer_read, vector.transfer_write, vector.extract_strided_slice, and vector.insert_strided_slice. With this, we have more opportunities to cancel extract/insert ops or forward write/read ops. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D95873
This commit is contained in:
parent
2fbbb18c1d
commit
874ce9b80f
|
@ -35,6 +35,15 @@ void populateVectorToVectorCanonicalizationPatterns(
|
||||||
void populateVectorToVectorTransformationPatterns(
|
void populateVectorToVectorTransformationPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||||
|
|
||||||
|
/// Collect a set of leading one dimension removal patterns.
|
||||||
|
///
|
||||||
|
/// These patterns insert vector.shape_cast to remove leading one dimensions
|
||||||
|
/// to expose more canonical forms of read/write/insert/extract operations.
|
||||||
|
/// With them, there are more chances that we can cancel out extract-insert
|
||||||
|
/// pairs or forward write-read pairs.
|
||||||
|
void populateCastAwayVectorLeadingOneDimPatterns(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||||
|
|
||||||
/// Collect a set of vector slices transformation patterns:
|
/// Collect a set of vector slices transformation patterns:
|
||||||
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
|
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
|
||||||
/// Useful for clients that want to express all vector "slices"
|
/// Useful for clients that want to express all vector "slices"
|
||||||
|
|
|
@ -2607,6 +2607,186 @@ struct TransferWriteInsertPattern
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Trims leading one dimensions from `oldType` and returns the result type.
|
||||||
|
// Returns `vector<1xT>` if `oldType` only has one element.
|
||||||
|
static VectorType trimLeadingOneDims(VectorType oldType) {
|
||||||
|
ArrayRef<int64_t> oldShape = oldType.getShape();
|
||||||
|
ArrayRef<int64_t> newShape =
|
||||||
|
oldShape.drop_while([](int64_t dim) { return dim == 1; });
|
||||||
|
// Make sure we have at least 1 dimension per vector type requirements.
|
||||||
|
if (newShape.empty())
|
||||||
|
newShape = oldShape.take_back();
|
||||||
|
return VectorType::get(newShape, oldType.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Casts away leading one dimensions in vector.extract_strided_slice's vector
|
||||||
|
// input by inserting vector.shape_cast.
|
||||||
|
struct CastAwayExtractStridedSliceLeadingOneDim
|
||||||
|
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// vector.extract_strided_slice requires the input and output vector to have
|
||||||
|
// the same rank. Here we drop leading one dimensions from the input vector
|
||||||
|
// type to make sure we don't cause mismatch.
|
||||||
|
VectorType oldSrcType = extractOp.getVectorType();
|
||||||
|
VectorType newSrcType = trimLeadingOneDims(oldSrcType);
|
||||||
|
|
||||||
|
if (newSrcType.getRank() == oldSrcType.getRank())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
|
||||||
|
|
||||||
|
VectorType oldDstType = extractOp.getType();
|
||||||
|
VectorType newDstType =
|
||||||
|
VectorType::get(oldDstType.getShape().drop_front(dropCount),
|
||||||
|
oldDstType.getElementType());
|
||||||
|
|
||||||
|
Location loc = extractOp.getLoc();
|
||||||
|
|
||||||
|
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
|
||||||
|
loc, newSrcType, extractOp.vector());
|
||||||
|
|
||||||
|
// The offsets/sizes/strides attribute can have a less number of elements
|
||||||
|
// than the input vector's rank: it is meant for the leading dimensions.
|
||||||
|
auto newOffsets = rewriter.getArrayAttr(
|
||||||
|
extractOp.offsets().getValue().drop_front(dropCount));
|
||||||
|
auto newSizes = rewriter.getArrayAttr(
|
||||||
|
extractOp.sizes().getValue().drop_front(dropCount));
|
||||||
|
auto newStrides = rewriter.getArrayAttr(
|
||||||
|
extractOp.strides().getValue().drop_front(dropCount));
|
||||||
|
|
||||||
|
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||||
|
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
|
||||||
|
newExtractOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Casts away leading one dimensions in vector.extract_strided_slice's vector
|
||||||
|
// inputs by inserting vector.shape_cast.
|
||||||
|
struct CastAwayInsertStridedSliceLeadingOneDim
|
||||||
|
: public OpRewritePattern<vector::InsertStridedSliceOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
VectorType oldSrcType = insertOp.getSourceVectorType();
|
||||||
|
VectorType newSrcType = trimLeadingOneDims(oldSrcType);
|
||||||
|
VectorType oldDstType = insertOp.getDestVectorType();
|
||||||
|
VectorType newDstType = trimLeadingOneDims(oldDstType);
|
||||||
|
|
||||||
|
if (newSrcType.getRank() == oldSrcType.getRank() &&
|
||||||
|
newDstType.getRank() == oldDstType.getRank())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Trim leading one dimensions from both operands.
|
||||||
|
Location loc = insertOp.getLoc();
|
||||||
|
|
||||||
|
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
|
||||||
|
loc, newSrcType, insertOp.source());
|
||||||
|
Value newDstVector =
|
||||||
|
rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
|
||||||
|
|
||||||
|
auto newOffsets = rewriter.getArrayAttr(
|
||||||
|
insertOp.offsets().getValue().take_back(newDstType.getRank()));
|
||||||
|
auto newStrides = rewriter.getArrayAttr(
|
||||||
|
insertOp.strides().getValue().take_back(newSrcType.getRank()));
|
||||||
|
|
||||||
|
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
|
||||||
|
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
|
||||||
|
newInsertOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turns vector.transfer_read on vector with leading 1 dimensions into
|
||||||
|
// vector.shape_cast followed by vector.transfer_read on vector without leading
|
||||||
|
// 1 dimensions.
|
||||||
|
struct CastAwayTransferReadLeadingOneDim
|
||||||
|
: public OpRewritePattern<vector::TransferReadOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(vector::TransferReadOp read,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto shapedType = read.source().getType().cast<ShapedType>();
|
||||||
|
if (shapedType.getElementType() != read.getVectorType().getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
VectorType oldType = read.getVectorType();
|
||||||
|
VectorType newType = trimLeadingOneDims(oldType);
|
||||||
|
|
||||||
|
if (newType == oldType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
AffineMap oldMap = read.permutation_map();
|
||||||
|
ArrayRef<AffineExpr> newResults =
|
||||||
|
oldMap.getResults().take_back(newType.getRank());
|
||||||
|
AffineMap newMap =
|
||||||
|
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
|
||||||
|
rewriter.getContext());
|
||||||
|
|
||||||
|
ArrayAttr mask;
|
||||||
|
if (read.masked())
|
||||||
|
mask = rewriter.getArrayAttr(
|
||||||
|
read.maskedAttr().getValue().take_back(newType.getRank()));
|
||||||
|
|
||||||
|
auto newRead = rewriter.create<vector::TransferReadOp>(
|
||||||
|
read.getLoc(), newType, read.source(), read.indices(), newMap,
|
||||||
|
read.padding(), mask);
|
||||||
|
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turns vector.transfer_write on vector with leading 1 dimensions into
|
||||||
|
// vector.shape_cast followed by vector.transfer_write on vector without leading
|
||||||
|
// 1 dimensions.
|
||||||
|
struct CastAwayTransferWriteLeadingOneDim
|
||||||
|
: public OpRewritePattern<vector::TransferWriteOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto shapedType = write.source().getType().dyn_cast<ShapedType>();
|
||||||
|
if (shapedType.getElementType() != write.getVectorType().getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
VectorType oldType = write.getVectorType();
|
||||||
|
VectorType newType = trimLeadingOneDims(oldType);
|
||||||
|
|
||||||
|
if (newType == oldType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
AffineMap oldMap = write.permutation_map();
|
||||||
|
ArrayRef<AffineExpr> newResults =
|
||||||
|
oldMap.getResults().take_back(newType.getRank());
|
||||||
|
AffineMap newMap =
|
||||||
|
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
|
||||||
|
rewriter.getContext());
|
||||||
|
|
||||||
|
ArrayAttr mask;
|
||||||
|
if (write.masked())
|
||||||
|
mask = rewriter.getArrayAttr(
|
||||||
|
write.maskedAttr().getValue().take_back(newType.getRank()));
|
||||||
|
|
||||||
|
auto newVector = rewriter.create<vector::ShapeCastOp>(
|
||||||
|
write.getLoc(), newType, write.vector());
|
||||||
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||||
|
write, newVector, write.source(), write.indices(), newMap, mask);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
||||||
// TODO: Add this as DRR pattern.
|
// TODO: Add this as DRR pattern.
|
||||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||||
|
@ -2622,6 +2802,15 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Adds to `patterns` the four patterns that cast away leading one dimensions
// (extract_strided_slice, insert_strided_slice, transfer_read, transfer_write)
// together with ShapeCastOpFolder.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // clang-format off
  patterns.insert<
      CastAwayExtractStridedSliceLeadingOneDim,
      CastAwayInsertStridedSliceLeadingOneDim,
      CastAwayTransferReadLeadingOneDim,
      CastAwayTransferWriteLeadingOneDim,
      ShapeCastOpFolder>(context);
  // clang-format on
}
|
||||||
|
|
||||||
void mlir::vector::populateVectorSlicesLoweringPatterns(
|
void mlir::vector::populateVectorSlicesLoweringPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||||
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
|
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
|
||||||
|
|
|
@ -601,3 +601,73 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
|
||||||
: vector<4x4xf32>, tensor<4x4xf32>
|
: vector<4x4xf32>, tensor<4x4xf32>
|
||||||
return %r : tensor<4x4xf32>
|
return %r : tensor<4x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test: the leading unit dimension of a vector.extract_strided_slice is cast
// away. The source is shape_cast to vector<8x8xf16>, the slice is taken on the
// lower-rank vector, and the result is shape_cast back to vector<1x1x8xf16>.
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
|
||||||
|
func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
|
||||||
|
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
|
||||||
|
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
|
||||||
|
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
|
||||||
|
// CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
|
||||||
|
// CHECK: return %[[RET]]
|
||||||
|
return %0: vector<1x1x8xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: both the source (vector<1x8xf16>) and the destination
// (vector<1x8x8xf16>) of a vector.insert_strided_slice are shape_cast to drop
// their leading unit dimension, and the result is shape_cast back.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
|
||||||
|
func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
|
||||||
|
// CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
|
||||||
|
// CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
|
||||||
|
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
|
||||||
|
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
|
||||||
|
// CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
|
||||||
|
// CHECK: return %[[RET]]
|
||||||
|
return %0: vector<1x8x8xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: all-ones shapes on vector.insert_strided_slice. Both operands are
// expected to be shape_cast down to vector<1xf16>, i.e. one dimension is kept.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
|
||||||
|
func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
|
||||||
|
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
|
||||||
|
// CHECK: vector.shape_cast %{{.+}} : vector<1x1x1xf16> to vector<1xf16>
|
||||||
|
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
|
||||||
|
return %0: vector<1x1x1xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: a vector.transfer_read of vector<1x4xf16> becomes a read of
// vector<4xf16> (with the `masked` attribute trimmed to one entry) followed by
// a shape_cast back to vector<1x4xf16>. The indices are unchanged.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
|
||||||
|
func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
|
||||||
|
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
// CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
|
||||||
|
%f0 = constant 0. : f16
|
||||||
|
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
|
||||||
|
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
|
||||||
|
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
|
||||||
|
// CHECK: return %[[CAST]]
|
||||||
|
return %0: vector<1x4xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: all-ones shape on vector.transfer_read. The rewritten read is expected
// to yield vector<1xf16>, which is shape_cast back to vector<1x1xf16>.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
|
||||||
|
func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%f0 = constant 0. : f16
|
||||||
|
// CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
|
||||||
|
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x1x1x1xf16>, vector<1x1xf16>
|
||||||
|
return %0: vector<1x1xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: a vector.transfer_write of vector<1x4xf16> becomes a shape_cast to
// vector<4xf16> followed by a write of that vector (with the `masked`
// attribute trimmed to one entry). The indices are unchanged.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
|
||||||
|
func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
|
||||||
|
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
// CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
|
||||||
|
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
|
||||||
|
|
||||||
|
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: all-ones shape on vector.transfer_write. The written vector is
// expected to be shape_cast from vector<1x1xf16> down to vector<1xf16>.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
|
||||||
|
func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
// CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
|
||||||
|
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
|
||||||
}
|
}
|
||||||
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||||
|
populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue