forked from OSchip/llvm-project
[mlir][vector] Add unrolling pattern for TransposeOp
Support unrolling for vector.transpose following the same interface as other vector unrolling ops. Differential Revision: https://reviews.llvm.org/D123688
This commit is contained in:
parent
26eec9e9db
commit
5b1b7108c8
|
@ -2217,6 +2217,7 @@ def Vector_CreateMaskOp :
|
|||
|
||||
def Vector_TransposeOp :
|
||||
Vector_Op<"transpose", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
|
||||
|
|
|
@ -4320,6 +4320,10 @@ LogicalResult vector::TransposeOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Shape reported to the vector unrolling interface: a copy of the static
/// shape of this op's result vector type.
Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  auto resultShape = getResultType().getShape();
  return llvm::to_vector<4>(resultShape);
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
|
||||
|
|
|
@ -681,14 +681,62 @@ private:
|
|||
const vector::UnrollVectorOptions options;
|
||||
};
|
||||
|
||||
/// Unrolls a vector.transpose into transposes of smaller tiles.
///
/// The result vector is walked in tiles of the target shape returned by
/// `getTargetShape(options, op)`. For every tile, the corresponding slice of
/// the source operand is extracted, transposed with the original permutation,
/// and inserted into an accumulator vector that starts as a zero constant of
/// the result type. The final accumulator replaces the original op.
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTranposePattern(MLIRContext *context,
                        const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}

  /// Rewrites `transposeOp` into per-tile extract / transpose / insert ops.
  /// Returns failure() without creating any op when the result has rank 0,
  /// when no target shape is available, or when no shape ratio between the
  /// result shape and the target shape can be computed.
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();

    auto originalVectorType = transposeOp.getResultType();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Check the ratio before dereferencing it, and before any IR is created,
    // so that a failed match leaves the IR unchanged.
    auto maybeRatio = shapeRatio(originalSize, *targetShape);
    if (!maybeRatio)
      return failure();
    int64_t sliceCount = computeMaxLinearIndex(*maybeRatio);

    Location loc = transposeOp.getLoc();
    // Unit strides are used for both the extract and the insert slices.
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);

    // Prepare the result vector: a zero constant that the transposed tiles
    // are inserted into one by one.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));

    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);

    for (int64_t i = 0; i < sliceCount; i++) {
      // Offsets of the i-th tile, expressed in the result vector.
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape: result dimension `index` is
      // taken from source dimension `value`, so scatter through the
      // permutation to map result coordinates back onto the operand.
      for (const auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the vector unrolling rewrite patterns listed below in `patterns`.
/// Each pattern is constructed with the pattern set's context and `options`.
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTranposePattern>(patterns.getContext(), options);
}
|
||||
|
||||
void mlir::vector::populatePropagateVectorDistributionPatterns(
|
||||
|
|
|
@ -107,6 +107,11 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
|
|||
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
|
||||
// CHECK: return %[[V2]] : vector<4xf32>
|
||||
|
||||
|
||||
// Test input: an <add> vector.reduction of an 8-element f32 vector to an f32 scalar.
func @vector_reduction(%v : vector<8xf32>) -> f32 {
|
||||
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: func @vector_reduction(
|
||||
// CHECK-SAME: %[[v:.*]]: vector<8xf32>
|
||||
// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
|
||||
|
@ -121,8 +126,35 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
|
|||
// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
|
||||
// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
|
||||
// CHECK: return %[[add3]]
|
||||
// Test input: an <add> vector.reduction of an 8-element f32 vector to an f32 scalar.
// NOTE(review): an identical definition appears earlier in this listing; the two
// copies look like the old and new positions of a moved function - confirm.
func @vector_reduction(%v : vector<8xf32>) -> f32 {
|
||||
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// Test input: transposes a 2x4x3x8 f32 vector with permutation [0, 2, 3, 1],
// yielding a 2x3x8x4 f32 vector.
func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
|
||||
%t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
|
||||
return %t : vector<2x3x8x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_tranpose
|
||||
// CHECK: %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
|
||||
// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
|
||||
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
|
||||
// CHECK: return %[[V7]] : vector<2x3x8x4xf32>
|
||||
|
|
|
@ -282,6 +282,12 @@ struct TestVectorUnrollingPatterns
|
|||
.setFilterConstraint([](Operation *op) {
|
||||
return success(isa<vector::ReductionOp>(op));
|
||||
}));
|
||||
populateVectorUnrollPatterns(
|
||||
patterns, UnrollVectorOptions()
|
||||
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
|
||||
.setFilterConstraint([](Operation *op) {
|
||||
return success(isa<vector::TransposeOp>(op));
|
||||
}));
|
||||
|
||||
if (unrollBasedOnType) {
|
||||
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
|
||||
|
|
Loading…
Reference in New Issue