[mlir][Vector] Add a pattern to lower 2-D vector.transpose to shape_cast+shuffle.

The 2-D case can be rewritten to generate quite fewer instructions and a single vector.shuffle which seems to provide a nice performance boost.
Add this arrow to our quiver by exposing it with a new vector transform option.

Differential Revision: https://reviews.llvm.org/D113062
This commit is contained in:
Nicolas Vasilache 2021-11-02 21:59:55 +00:00
parent c964afb2c8
commit 885072820c
4 changed files with 87 additions and 2 deletions

View File

@ -29,6 +29,8 @@ enum class VectorTransposeLowering {
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
/// Lower 2-D transpose to `vector.shuffle`.
Shuffle = 2,
};
/// Enum to control the lowering of `vector.multi_reduction` operations.
enum class VectorMultiReductionLowering {

View File

@ -686,6 +686,12 @@ public:
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Shuffle &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
@ -740,6 +746,61 @@ private:
vector::VectorTransformsOptions vectorTransformOptions;
};
/// Rewrite a 2-D vector.transpose as a sequence of:
/// vector.shape_cast 2D -> 1D
/// vector.shuffle
/// vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
TransposeOp2DToShuffleLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context)
: OpRewritePattern<vector::TransposeOp>(context),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType srcType = op.getVectorType();
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
SmallVector<int64_t, 4> transp;
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 && transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
if (vectorTransformOptions.vectorTransposeLowering !=
VectorTransposeLowering::Shuffle)
return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
Value casted = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
SmallVector<int64_t> mask;
mask.reserve(m * n);
for (int64_t j = 0; j < n; ++j)
for (int64_t i = 0; i < m; ++i)
mask.push_back(i * n + j);
Value shuffled =
rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
shuffled);
return success();
}
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
};
/// Progressive lowering of OuterProductOp.
/// One:
/// %x = vector.outerproduct %lhs, %rhs, %acc
@ -3656,7 +3717,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options) {
patterns.add<TransposeOpLowering>(options, patterns.getContext());
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext());
}
void mlir::vector::populateVectorReductionToContractPatterns(

View File

@ -0,0 +1,14 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
// CHECK-LABEL: func @transpose
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
// CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
// 0 4
// 0 1 2 3 1 5
// 4 5 6 7 -> 2 6
// 3 7
// CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
return %0 : vector<4x2xf32>
}

View File

@ -116,6 +116,10 @@ struct TestVectorContractionConversion
*this, "vector-flat-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToShuffleTranspose{
*this, "vector-shuffle-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@ -165,12 +169,15 @@ struct TestVectorContractionConversion
VectorTransposeLowering::EltWise;
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
if (lowerToShuffleTranspose)
transposeLowering = VectorTransposeLowering::Shuffle;
VectorTransformsOptions options{
contractLowering, vectorMultiReductionLowering, transposeLowering};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
if (!lowerToShuffleTranspose)
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}