forked from OSchip/llvm-project
[mlir] [VectorOps] Improve lowering of extract_strided_slice (and friends like shape_cast)
Using a shuffle for the last recursive step in progressive lowering not only results in much more compact IR, but also more efficient code (since the backend is no longer confused on subvector aliasing for longer vectors). E.g. the following %f = vector.shape_cast %v0: vector<1024xf32> to vector<32x32xf32> yields much better x86-64 code that runs 3x faster than the original. Reviewed By: bkramer, nicolasvasilache Differential Revision: https://reviews.llvm.org/D85482
This commit is contained in:
parent
25e38c3f3c
commit
c3c95b9c80
|
@ -1349,9 +1349,9 @@ private:
|
|||
};
|
||||
|
||||
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||
/// 1. extractelement + insertelement for the 1-D case
|
||||
/// 2. extract + optional strided_slice + insert for the n-D case.
|
||||
class VectorStridedSliceOpConversion
|
||||
/// 1. express single offset extract as a direct shuffle.
|
||||
/// 2. extract + lower rank strided_slice + insert for the n-D case.
|
||||
class VectorExtractStridedSliceOpConversion
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||
|
@ -1371,21 +1371,34 @@ public:
|
|||
auto loc = op.getLoc();
|
||||
auto elemType = dstType.getElementType();
|
||||
assert(elemType.isSignlessIntOrIndexOrFloat());
|
||||
|
||||
// Single offset can be more efficiently shuffled.
|
||||
if (op.offsets().getValue().size() == 1) {
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
offsets.reserve(size);
|
||||
for (int64_t off = offset, e = offset + size * stride; off < e;
|
||||
off += stride)
|
||||
offsets.push_back(off);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
|
||||
op.vector(),
|
||||
rewriter.getI64ArrayAttr(offsets));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Extract/insert on a lower ranked extract strided slice op.
|
||||
Value zero = rewriter.create<ConstantOp>(loc, elemType,
|
||||
rewriter.getZeroAttr(elemType));
|
||||
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
Value extracted = extractOne(rewriter, loc, op.vector(), off);
|
||||
if (op.offsets().getValue().size() > 1) {
|
||||
extracted = rewriter.create<ExtractStridedSliceOp>(
|
||||
loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
}
|
||||
Value one = extractOne(rewriter, loc, op.vector(), off);
|
||||
Value extracted = rewriter.create<ExtractStridedSliceOp>(
|
||||
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||
}
|
||||
rewriter.replaceOp(op, {res});
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
/// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
|
||||
|
@ -1404,7 +1417,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
|||
patterns.insert<VectorFMAOpNDRewritePattern,
|
||||
VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorStridedSliceOpConversion>(ctx);
|
||||
VectorExtractStridedSliceOpConversion>(ctx);
|
||||
patterns.insert<VectorReductionOpConversion>(
|
||||
ctx, converter, reassociateFPReductions);
|
||||
patterns
|
||||
|
|
|
@ -512,65 +512,38 @@ func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
|
|||
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice1
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float>
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice1(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.vec<4 x float>)
|
||||
// CHECK: %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : !llvm.vec<4 x float>, !llvm.vec<4 x float>
|
||||
// CHECK: llvm.return %[[T0]] : !llvm.vec<2 x float>
|
||||
|
||||
func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
|
||||
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
|
||||
return %0 : vector<2x8xf32>
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice2
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice2(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>)
|
||||
// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vec<8 x float>>
|
||||
// CHECK: llvm.return %[[T4]] : !llvm.array<2 x vec<8 x float>>
|
||||
|
||||
func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
|
||||
%0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
|
||||
return %0 : vector<2x2xf32>
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice3
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>>
|
||||
//
|
||||
// Subvector vector<8xf32> @2
|
||||
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.array<2 x vec<2 x float>>
|
||||
//
|
||||
// Subvector vector<8xf32> @3
|
||||
// CHECK: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float>
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float>
|
||||
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.array<2 x vec<2 x float>>
|
||||
// CHECK-LABEL: llvm.func @extract_strided_slice3(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>)
|
||||
// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>>
|
||||
// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float>
|
||||
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<2 x vec<2 x float>>
|
||||
// CHECK: %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>>
|
||||
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float>
|
||||
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vec<2 x float>>
|
||||
// CHECK: llvm.return %[[T7]] : !llvm.array<2 x vec<2 x float>>
|
||||
|
||||
func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
|
||||
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
|
||||
|
@ -674,15 +647,11 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
|
|||
}
|
||||
// CHECK-LABEL: llvm.func @extract_strides(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.array<3 x vec<3 x float>>)
|
||||
// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>>
|
||||
// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>>
|
||||
// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm.vec<1 x float>
|
||||
// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm.vec<3 x float>
|
||||
// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm.vec<1 x float>
|
||||
// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm.array<1 x vec<1 x float>>
|
||||
// CHECK: llvm.return %[[s8]] : !llvm.array<1 x vec<1 x float>>
|
||||
// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>>
|
||||
// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>>
|
||||
// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2] : !llvm.vec<3 x float>, !llvm.vec<3 x float>
|
||||
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<1 x vec<1 x float>>
|
||||
// CHECK: llvm.return %[[T4]] : !llvm.array<1 x vec<1 x float>>
|
||||
|
||||
// CHECK-LABEL: llvm.func @vector_fma(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.vec<8 x float>, %[[B:.*]]: !llvm.array<2 x vec<4 x float>>)
|
||||
|
|
Loading…
Reference in New Issue