From 73863648892ee7063c7fd4e658d7614fd721504a Mon Sep 17 00:00:00 2001 From: Stephen Neuendorffer Date: Tue, 30 Nov 2021 14:37:30 -0800 Subject: [PATCH] Revert "[MLIR] Update Vector To LLVM conversion to be aware of assume_alignment" This reverts commit 29a50c5864ddab283c1ff38694fb5926ce37b39a. After LLVM lowering, the original patch incorrectly moved alignment information across an unconstrained GEP operation. This is only correct for some index offsets in the GEP. It seems that the best approach is, in fact, to rely on LLVM to propagate information from the llvm.assume() to users. Thanks to Thomas Raoux for catching this. --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 31 +------ .../VectorToLLVM/vector-to-llvm.mlir | 80 ------------------- 2 files changed, 3 insertions(+), 108 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5ea34d03bec7..bc42922a4485 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -84,30 +84,6 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, return success(); } -// Return the minimal alignment value that satisfies all the AssumeAlignment -// uses of `value`. If no such uses exist, return 1. -static unsigned getAssumedAlignment(Value value) { - unsigned align = 1; - for (auto &u : value.getUses()) { - Operation *owner = u.getOwner(); - if (auto op = dyn_cast(owner)) - align = mlir::lcm(align, op.alignment()); - } - return align; -} - -// Helper that returns data layout alignment of a memref associated with a -// load, store, scatter, or gather op, including additional information from -// assume_alignment calls on the source of the transfer -template -LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter, - OpAdaptor op, unsigned &align) { - if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align))) - return failure(); - align = std::max(align, getAssumedAlignment(op.base())); - return success(); -} - // Add an index vector component to a base pointer. This almost always succeeds // unless the last stride is non-unit or the memory space is not zero. static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, @@ -246,8 +222,7 @@ public: // Resolve alignment. unsigned align; - if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp, - align))) + if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) return failure(); // Resolve address. @@ -276,7 +251,7 @@ public: // Resolve alignment. unsigned align; - if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); // Resolve address. @@ -310,7 +285,7 @@ public: // Resolve alignment. unsigned align; - if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); // Resolve address. diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 42a264a8aa97..ce81c4e36bb6 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1293,26 +1293,6 @@ func @transfer_read_index_1d(%A : memref, %base: index) -> vector<17xin // ----- -func @transfer_read_1d_aligned(%A : memref, %base: index) -> vector<17xf32> { - memref.assume_alignment %A, 32 : memref - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7 - {permutation_map = affine_map<(d0) -> (d0)>} : - memref, vector<17xf32> - vector.transfer_write %f, %A[%base] - {permutation_map = affine_map<(d0) -> (d0)>} : - vector<17xf32>, memref - return %f: vector<17xf32> -} -// CHECK: llvm.intr.masked.load -// CHECK-SAME: {alignment = 32 : i32} -// CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xf32>) -> vector<17xf32> -// CHECK: llvm.intr.masked.store -// CHECK-SAME: {alignment = 32 : i32} -// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr> - -// ----- - func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> { %f7 = arith.constant 7.0: f32 %f = vector.transfer_read %A[%base0, %base1], %f7 @@ -1485,22 +1465,6 @@ func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : ind // ----- -func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { - memref.assume_alignment %memref, 32 : memref<200x100xf32> - %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32> - return %0 : vector<8xf32> -} - -// CHECK-LABEL: func @vector_load_op_aligned -// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 -// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 -// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 -// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> -// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr> - -// ----- - func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) { %val = arith.constant dense<11.0> : vector<4xf32> vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32> @@ -1527,23 +1491,6 @@ func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : in // ----- -func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) { - memref.assume_alignment %memref, 32 : memref<200x100xf32> - %val = arith.constant dense<11.0> : vector<4xf32> - vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32> - return -} - -// CHECK-LABEL: func @vector_store_op_aligned -// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 -// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 -// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 -// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> -// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr> - -// ----- - func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { %c0 = arith.constant 0: index %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> @@ -1621,20 +1568,6 @@ func @gather_op_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: vec // ----- -func @gather_op_aligned(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { - memref.assume_alignment %arg0, 32 : memref - %0 = arith.constant 0: index - %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> - return %1 : vector<3xf32> -} - -// CHECK-LABEL: func @gather_op_aligned -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> -// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> -// CHECK: return %[[G]] : vector<3xf32> - -// ----- - func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> { %0 = arith.constant 3 : index %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> @@ -1673,19 +1606,6 @@ func @scatter_op_index(%arg0: memref, %arg1: vector<3xindex>, %arg2: ve // ----- -func @scatter_op_aligned(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { - memref.assume_alignment %arg0, 32 : memref - %0 = arith.constant 0: index - vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> - return -} - -// CHECK-LABEL: func @scatter_op_aligned -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> -// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr> - -// ----- - func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) { %0 = arith.constant 3 : index vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>