[mlir][vector] Add pattern to shuffle bitcast ops

These patterns move vector.bitcast ops to be before
insert ops or after extract ops where suitable.
With them, bitcast will happen on smaller vectors
and there are more chances to share extract/insert
ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96040
This commit is contained in:
Lei Zhang 2021-02-05 17:48:09 -05:00
parent a4fa667dee
commit 7630520ae3
4 changed files with 343 additions and 0 deletions

View File

@ -44,6 +44,14 @@ void populateVectorToVectorTransformationPatterns(
void populateCastAwayVectorLeadingOneDimPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
MLIRContext *context);
/// Collect a set of vector slices transformation patterns:
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
/// Useful for clients that want to express all vector "slices"

View File

@ -2787,6 +2787,244 @@ struct CastAwayTransferWriteLeadingOneDim
}
};
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr attr) { return attr.getInt(); }));
};
// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
// %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
// %1 = vector.extract %0[3] : vector<8xf16>
// Into:
// %0 = vector.extract %src[1] : vector<4xf32>
// %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
// %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
: public OpRewritePattern<vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Only support extracting scalars for now.
if (extractOp.getVectorType().getRank() != 1)
return failure();
auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
if (!castOp)
return failure();
VectorType castSrcType = castOp.getSourceVectorType();
VectorType castDstType = castOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
// Fail to match if we only have one element in the cast op source.
// This is to avoid infinite loop given that this pattern can generate
// such cases.
if (castSrcType.getNumElements() == 1)
return failure();
// Only support casting to a larger number of elements or now.
// E.g., vector<4xf32> -> vector<8xf16>.
if (castSrcType.getNumElements() > castDstType.getNumElements())
return failure();
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
};
uint64_t index = getFirstIntValue(extractOp.position());
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
VectorType oneScalarType =
VectorType::get({1}, castSrcType.getElementType());
Value packedValue = rewriter.create<vector::ExtractOp>(
extractOp.getLoc(), oneScalarType, castOp.source(),
rewriter.getI64ArrayAttr(index / expandRatio));
// Cast it to a vector with the desired scalar's type.
// E.g. f32 -> vector<2xf16>
VectorType packedType =
VectorType::get({expandRatio}, castDstType.getElementType());
Value castedValue = rewriter.create<vector::BitCastOp>(
extractOp.getLoc(), packedType, packedValue);
// Finally extract the desired scalar.
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
extractOp, extractOp.getType(), castedValue,
rewriter.getI64ArrayAttr(index % expandRatio));
return success();
}
};
// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
// %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
// %0 = vector.extract_strided_slice %cast {
// offsets = [4], sizes = [4], strides = [1]
// } : vector<8xf16> to vector<4xf16>
// Into:
// %0 = vector.extract_strided_slice %src {
// offsets = [2], sizes = [2], strides = [1]
// } : vector<4xf32> to vector<2xf32>
// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
if (!castOp)
return failure();
VectorType castSrcType = castOp.getSourceVectorType();
VectorType castDstType = castOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
int64_t castSrcLastDim = castSrcType.getShape().back();
int64_t castDstLastDim = castDstType.getShape().back();
// Require casting to more elements for now; other cases to be implemented.
if (castSrcLastDim > castDstLastDim)
return failure();
// Only accept all one strides for now.
if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
[](const APInt &val) { return !val.isOneValue(); }))
return failure();
unsigned rank = extractOp.getVectorType().getRank();
assert(castDstLastDim % castSrcLastDim == 0);
int64_t expandRatio = castDstLastDim / castSrcLastDim;
// If we have a less number of offsets than the rank, then implicitly we
// are selecting the full range for the last bitcasted dimension; other
// dimensions aren't affected. Otherwise, we need to scale down the last
// dimension's offset given we are extracting from less elements now.
ArrayAttr newOffsets = extractOp.offsets();
if (newOffsets.size() == rank) {
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
if (offsets.back() % expandRatio != 0)
return failure();
offsets.back() = offsets.back() / expandRatio;
newOffsets = rewriter.getI64ArrayAttr(offsets);
}
// Similarly for sizes.
ArrayAttr newSizes = extractOp.sizes();
if (newSizes.size() == rank) {
SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
if (sizes.back() % expandRatio != 0)
return failure();
sizes.back() = sizes.back() / expandRatio;
newSizes = rewriter.getI64ArrayAttr(sizes);
}
SmallVector<int64_t, 4> dims =
llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
dims.back() = dims.back() / expandRatio;
VectorType newExtractType =
VectorType::get(dims, castSrcType.getElementType());
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
newSizes, extractOp.strides());
rewriter.replaceOpWithNewOp<vector::BitCastOp>(
extractOp, extractOp.getType(), newExtractOp);
return success();
}
};
// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
// %0 = vector.insert_strided_slice %src, %dst {
// offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
// %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
// %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
// %2 = vector.insert_strided_slice %src, %dst {
// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
: public OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
VectorType castSrcType = bitcastOp.getSourceVectorType();
VectorType castDstType = bitcastOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
int64_t castSrcLastDim = castSrcType.getShape().back();
int64_t castDstLastDim = castDstType.getShape().back();
// Require casting to less elements for now; other cases to be implemented.
if (castSrcLastDim < castDstLastDim)
return failure();
assert(castSrcLastDim % castDstLastDim == 0);
int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
auto insertOp =
bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
if (!insertOp)
return failure();
// Only accept all one strides for now.
if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
[](const APInt &val) { return !val.isOneValue(); }))
return failure();
unsigned rank = insertOp.getSourceVectorType().getRank();
// Require insert op to have the same rank for the source and destination
// vector; other cases to be implemented.
if (rank != insertOp.getDestVectorType().getRank())
return failure();
ArrayAttr newOffsets = insertOp.offsets();
assert(newOffsets.size() == rank);
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
if (offsets.back() % shrinkRatio != 0)
return failure();
offsets.back() = offsets.back() / shrinkRatio;
newOffsets = rewriter.getI64ArrayAttr(offsets);
SmallVector<int64_t, 4> srcDims =
llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
srcDims.back() = srcDims.back() / shrinkRatio;
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
bitcastOp.getLoc(), newCastSrcType, insertOp.source());
SmallVector<int64_t, 4> dstDims =
llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
dstDims.back() = dstDims.back() / shrinkRatio;
VectorType newCastDstType =
VectorType::get(dstDims, castDstType.getElementType());
auto newCastDstOp = rewriter.create<vector::BitCastOp>(
bitcastOp.getLoc(), newCastDstType, insertOp.dest());
rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
insertOp.strides());
return success();
}
};
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
@ -2811,6 +3049,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
context);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
BubbleUpBitCastForStridedSliceInsert>(context);
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);

View File

@ -671,3 +671,92 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
return
}
// CHECK-LABEL: func @bubble_down_bitcast_in_extract
// CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
%0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
// CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
%1 = vector.extract %0[3] : vector<8xf16>
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
// CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
%2 = vector.extract %0[4] : vector<8xf16>
// CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
return %1, %2: f16, f16
}
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
// CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
%cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
%0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: return %[[CAST]]
return %0: vector<4xf16>
}
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
// CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
%cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
%0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<2x4xf16>
}
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
// CHECK: vector.bitcast
// CHECK-NEXT: vector.extract_strided_slice
%cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
%0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
return %0: vector<4xf16>
}
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
// CHECK: vector.bitcast
// CHECK-NEXT: vector.extract_strided_slice
%cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
%0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
return %0: vector<3xf16>
}
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
// CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
// CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
// CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
%0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%1 = vector.insert_strided_slice %src2, %0 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
// CHECK: return %[[INSERT2]]
return %cast: vector<4xf32>
}
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
// CHECK: vector.insert_strided_slice
// CHECK-NEXT: vector.bitcast
%0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
%cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
return %cast: vector<4xf32>
}
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
// CHECK: vector.insert_strided_slice
// CHECK-NEXT: vector.bitcast
%0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
%cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
return %cast: vector<16x4x4xf32>
}

View File

@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
}
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
populateBubbleVectorBitCastOpPatterns(patterns, ctx);
populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}