Revert "[MLIR] Update Vector To LLVM conversion to be aware of assume_alignment"

This reverts commit 29a50c5864.

The original patch incorrectly moved alignment information across an
unconstrained GEP operation during LLVM lowering.  This is only correct
for some index offsets in the GEP.  It seems that the best approach is,
in fact, to rely on LLVM to propagate the information from the
llvm.assume() call to its users.
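
To illustrate the problem, here is a minimal sketch (not taken from the
patch; the function name and shapes are hypothetical): assume_alignment
only constrains the base pointer of the memref, while the lowered load
reads through a GEP at an arbitrary element offset.

  func @aligned_load_sketch(%A : memref<100xf32>, %i : index) -> vector<8xf32> {
    // The assumption constrains %A itself to a 32-byte boundary ...
    memref.assume_alignment %A, 32 : memref<100xf32>
    // ... but this load lowers to llvm.getelementptr %base[%i], a bitcast,
    // and an llvm.load. The GEP result is 32-byte aligned only when %i is a
    // multiple of 8 (8 x f32 = 32 bytes), so tagging the lowered load with
    // `alignment = 32` is wrong for general %i.
    %v = vector.load %A[%i] : memref<100xf32>, vector<8xf32>
    return %v : vector<8xf32>
  }

Relying on the llvm.assume()-based lowering of memref.assume_alignment
instead, as described above, lets LLVM derive the alignment of each
access only where it actually holds.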

Thanks to Thomas Raoux for catching this.
Author: Stephen Neuendorffer
Date:   2021-11-30 14:37:30 -08:00
Commit: 7386364889
Parent: 9b704d31b5

2 changed files with 3 additions and 108 deletions

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

@@ -84,30 +84,6 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Return the minimal alignment value that satisfies all the AssumeAlignment
-// uses of `value`. If no such uses exist, return 1.
-static unsigned getAssumedAlignment(Value value) {
-  unsigned align = 1;
-  for (auto &u : value.getUses()) {
-    Operation *owner = u.getOwner();
-    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
-      align = mlir::lcm(align, op.alignment());
-  }
-  return align;
-}
-
-// Helper that returns data layout alignment of a memref associated with a
-// load, store, scatter, or gather op, including additional information from
-// assume_alignment calls on the source of the transfer
-template <class OpAdaptor>
-LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
-                                   OpAdaptor op, unsigned &align) {
-  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
-    return failure();
-  align = std::max(align, getAssumedAlignment(op.base()));
-  return success();
-}
-
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -246,8 +222,7 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
-                                    align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
@@ -276,7 +251,7 @@ public:
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
@@ -310,7 +285,7 @@ public:
    // Resolve alignment.
    unsigned align;
-   if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
+   if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();
 
    // Resolve address.

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

@@ -1293,26 +1293,6 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 // -----
 
-func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
-  memref.assume_alignment %A, 32 : memref<?xf32>
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    memref<?xf32>, vector<17xf32>
-  vector.transfer_write %f, %A[%base]
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    vector<17xf32>, memref<?xf32>
-  return %f: vector<17xf32>
-}
-
-// CHECK: llvm.intr.masked.load
-// CHECK-SAME: {alignment = 32 : i32}
-// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
-// CHECK: llvm.intr.masked.store
-// CHECK-SAME: {alignment = 32 : i32}
-// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
-
 // -----
 
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = arith.constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1485,22 +1465,6 @@ func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : ind
 // -----
 
-func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func @vector_load_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
-// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<8xf32>>
-
 // -----
 
 func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = arith.constant dense<11.0> : vector<4xf32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1527,23 +1491,6 @@ func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : in
 // -----
 
-func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %val = arith.constant dense<11.0> : vector<4xf32>
-  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
-  return
-}
-
-// CHECK-LABEL: func @vector_store_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
-// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<4xf32>>
-
 // -----
 
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1621,20 +1568,6 @@ func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vec
 // -----
 
-func @gather_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
-  return %1 : vector<3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: return %[[G]] : vector<3xf32>
-
 // -----
 
 func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1673,19 +1606,6 @@ func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: ve
 // -----
 
-func @scatter_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
-  return
-}
-
-// CHECK-LABEL: func @scatter_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
-
 // -----
 
 func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
   %0 = arith.constant 3 : index
   vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>