forked from OSchip/llvm-project
[mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.
The 2-D case can be rewritten to generate quite fewer instructions and a single vector.shuffle which seems to provide a nice performance boost. Add this arrow to our quiver by exposing it with a new vector transform option. Differential Revision: https://reviews.llvm.org/D113062
This commit is contained in:
parent
c964afb2c8
commit
885072820c
|
@ -29,6 +29,8 @@ enum class VectorTransposeLowering {
|
|||
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
|
||||
/// intrinsics.
|
||||
Flat = 1,
|
||||
/// Lower 2-D transpose to `vector.shuffle`.
|
||||
Shuffle = 2,
|
||||
};
|
||||
/// Enum to control the lowering of `vector.multi_reduction` operations.
|
||||
enum class VectorMultiReductionLowering {
|
||||
|
|
|
@ -686,6 +686,12 @@ public:
|
|||
for (auto attr : op.transp())
|
||||
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
|
||||
if (vectorTransformOptions.vectorTransposeLowering ==
|
||||
vector::VectorTransposeLowering::Shuffle &&
|
||||
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Options specifies lowering to shuffle");
|
||||
|
||||
// Handle a true 2-D matrix transpose differently when requested.
|
||||
if (vectorTransformOptions.vectorTransposeLowering ==
|
||||
vector::VectorTransposeLowering::Flat &&
|
||||
|
@ -740,6 +746,61 @@ private:
|
|||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
};
|
||||
|
||||
/// Rewrite a 2-D vector.transpose as a sequence of:
|
||||
/// vector.shape_cast 2D -> 1D
|
||||
/// vector.shuffle
|
||||
/// vector.shape_cast 1D -> 2D
|
||||
class TransposeOp2DToShuffleLowering
|
||||
: public OpRewritePattern<vector::TransposeOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
|
||||
|
||||
TransposeOp2DToShuffleLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context)
|
||||
: OpRewritePattern<vector::TransposeOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType srcType = op.getVectorType();
|
||||
if (srcType.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
|
||||
|
||||
SmallVector<int64_t, 4> transp;
|
||||
for (auto attr : op.transp())
|
||||
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
if (transp[0] != 1 && transp[1] != 0)
|
||||
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
|
||||
|
||||
if (vectorTransformOptions.vectorTransposeLowering !=
|
||||
VectorTransposeLowering::Shuffle)
|
||||
return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
|
||||
|
||||
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
|
||||
Value casted = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
|
||||
SmallVector<int64_t> mask;
|
||||
mask.reserve(m * n);
|
||||
for (int64_t j = 0; j < n; ++j)
|
||||
for (int64_t i = 0; i < m; ++i)
|
||||
mask.push_back(i * n + j);
|
||||
|
||||
Value shuffled =
|
||||
rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
|
||||
shuffled);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
};
|
||||
|
||||
/// Progressive lowering of OuterProductOp.
|
||||
/// One:
|
||||
/// %x = vector.outerproduct %lhs, %rhs, %acc
|
||||
|
@ -3656,7 +3717,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
|||
|
||||
void mlir::vector::populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns, VectorTransformsOptions options) {
|
||||
patterns.add<TransposeOpLowering>(options, patterns.getContext());
|
||||
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
|
||||
options, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorReductionToContractPatterns(
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @transpose
|
||||
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
|
||||
// CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
|
||||
// 0 4
|
||||
// 0 1 2 3 1 5
|
||||
// 4 5 6 7 -> 2 6
|
||||
// 3 7
|
||||
// CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
|
||||
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
|
||||
%0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
|
||||
return %0 : vector<4x2xf32>
|
||||
}
|
|
@ -116,6 +116,10 @@ struct TestVectorContractionConversion
|
|||
*this, "vector-flat-transpose",
|
||||
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> lowerToShuffleTranspose{
|
||||
*this, "vector-shuffle-transpose",
|
||||
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> lowerToOuterProduct{
|
||||
*this, "vector-outerproduct",
|
||||
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
||||
|
@ -165,12 +169,15 @@ struct TestVectorContractionConversion
|
|||
VectorTransposeLowering::EltWise;
|
||||
if (lowerToFlatTranspose)
|
||||
transposeLowering = VectorTransposeLowering::Flat;
|
||||
if (lowerToShuffleTranspose)
|
||||
transposeLowering = VectorTransposeLowering::Shuffle;
|
||||
VectorTransformsOptions options{
|
||||
contractLowering, vectorMultiReductionLowering, transposeLowering};
|
||||
populateVectorBroadcastLoweringPatterns(patterns);
|
||||
populateVectorContractLoweringPatterns(patterns, options);
|
||||
populateVectorMaskOpLoweringPatterns(patterns);
|
||||
populateVectorShapeCastLoweringPatterns(patterns);
|
||||
if (!lowerToShuffleTranspose)
|
||||
populateVectorShapeCastLoweringPatterns(patterns);
|
||||
populateVectorTransposeLoweringPatterns(patterns, options);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue