From c3c95b9c808519662afe8b9053aa88b5be451d1d Mon Sep 17 00:00:00 2001 From: aartbik Date: Thu, 6 Aug 2020 15:34:47 -0700 Subject: [PATCH] [mlir] [VectorOps] Improve lowering of extract_strided_slice (and friends like shape_cast) Using a shuffle for the last recursive step in progressive lowering not only results in much more compact IR, but also more efficient code (since the backend is no longer confused on subvector aliasing for longer vectors). E.g. the following %f = vector.shape_cast %v0: vector<1024xf32> to vector<32x32xf32> yields much better x86-64 code that runs 3x faster than the original. Reviewed By: bkramer, nicolasvasilache Differential Revision: https://reviews.llvm.org/D85482 --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 37 +++++--- .../VectorToLLVM/vector-to-llvm.mlir | 85 ++++++------------- 2 files changed, 52 insertions(+), 70 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 12d3f5042bcd..1e92b80d830f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1349,9 +1349,9 @@ private: }; /// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. extractelement + insertelement for the 1-D case -/// 2. extract + optional strided_slice + insert for the n-D case. -class VectorStridedSliceOpConversion +/// 1. express single offset extract as a direct shuffle. +/// 2. extract + lower rank strided_slice + insert for the n-D case. +class VectorExtractStridedSliceOpConversion : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1371,21 +1371,34 @@ public: auto loc = op.getLoc(); auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); + + // Single offset can be more efficiently shuffled. + if (op.offsets().getValue().size() == 1) { + SmallVector offsets; + offsets.reserve(size); + for (int64_t off = offset, e = offset + size * stride; off < e; + off += stride) + offsets.push_back(off); + rewriter.replaceOpWithNewOp(op, dstType, op.vector(), + op.vector(), + rewriter.getI64ArrayAttr(offsets)); + return success(); + } + + // Extract/insert on a lower ranked extract strided slice op. Value zero = rewriter.create(loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value extracted = extractOne(rewriter, loc, op.vector(), off); - if (op.offsets().getValue().size() > 1) { - extracted = rewriter.create( - loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.sizes(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - } + Value one = extractOne(rewriter, loc, op.vector(), off); + Value extracted = rewriter.create( + loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.sizes(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); res = insertOne(rewriter, loc, extracted, res, idx); } - rewriter.replaceOp(op, {res}); + rewriter.replaceOp(op, res); return success(); } /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is @@ -1404,7 +1417,7 @@ void mlir::populateVectorToLLVMConversionPatterns( patterns.insert(ctx); + VectorExtractStridedSliceOpConversion>(ctx); patterns.insert( ctx, converter, reassociateFPReductions); patterns diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 5254d2eef4bf..d91d4db06106 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -512,65 +512,38 @@ func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> return %0 : vector<2xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice1 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float> +// CHECK-LABEL: llvm.func @extract_strided_slice1( +// CHECK-SAME: %[[A:.*]]: !llvm.vec<4 x float>) +// CHECK: %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : !llvm.vec<4 x float>, !llvm.vec<4 x float> +// CHECK: llvm.return %[[T0]] : !llvm.vec<2 x float> func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> return %0 : vector<2x8xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice2 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm.array<2 x vec<8 x float>> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.array<2 x vec<8 x float>> -// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.array<2 x vec<8 x float>> +// CHECK-LABEL: llvm.func @extract_strided_slice2( +// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>) +// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vec<8 x float>> +// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vec<8 x float>> +// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vec<8 x float>> +// CHECK: llvm.return %[[T4]] : !llvm.array<2 x vec<8 x float>> func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> return %0 : vector<2x2xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice3 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>> -// -// Subvector vector<8xf32> @2 -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.array<2 x vec<2 x float>> -// -// Subvector vector<8xf32> @3 -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.array<2 x vec<2 x float>> +// CHECK-LABEL: llvm.func @extract_strided_slice3( +// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>) +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<2 x vec<2 x float>> +// CHECK: %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vec<2 x float>> +// CHECK: llvm.return %[[T7]] : !llvm.array<2 x vec<2 x float>> func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> { %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32> @@ -674,15 +647,11 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> { } // CHECK-LABEL: llvm.func @extract_strides( // CHECK-SAME: %[[A:.*]]: !llvm.array<3 x vec<3 x float>>) -// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>> -// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>> -// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm.vec<1 x float> -// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm.vec<3 x float> -// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm.vec<1 x float> -// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm.array<1 x vec<1 x float>> -// CHECK: llvm.return %[[s8]] : !llvm.array<1 x vec<1 x float>> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>> +// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2] : !llvm.vec<3 x float>, !llvm.vec<3 x float> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<1 x vec<1 x float>> +// CHECK: llvm.return %[[T4]] : !llvm.array<1 x vec<1 x float>> // CHECK-LABEL: llvm.func @vector_fma( // CHECK-SAME: %[[A:.*]]: !llvm.vec<8 x float>, %[[B:.*]]: !llvm.array<2 x vec<4 x float>>)