Revert "[MLIR] Update Vector To LLVM conversion to be aware of assume_alignment"
This reverts commit 29a50c5864.
During lowering to LLVM, the original patch incorrectly moved alignment
information across an unconstrained GEP operation; this is only correct
for certain index offsets in the GEP. The best approach is, in fact, to
rely on LLVM to propagate the information from the llvm.assume() to its
users.
Thanks to Thomas Raoux for catching this.
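
To illustrate the problem, here is a minimal sketch in MLIR (illustrative only, not taken from the patch; the value names and types are hypothetical). The reverted code attached the alignment assumed for the memref's base pointer to the lowered load/store itself, but the lowered access goes through a GEP whose offset is generally unknown, so the base alignment carries over only for offsets that are multiples of it:

  // The base pointer of %A is asserted to be 32-byte aligned ...
  memref.assume_alignment %A, 32 : memref<?xf32>
  // ... but this op reads from base + %i * 4 (f32 elements). Tagging the
  // lowered load with "alignment = 32" is sound only when %i * 4 is a
  // multiple of 32; for %i = 1 the address is merely 4-byte aligned.
  %v = vector.load %A[%i] : memref<?xf32>, vector<8xf32>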
parent 9b704d31b5
commit 7386364889
@@ -84,30 +84,6 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Return the minimal alignment value that satisfies all the AssumeAlignment
-// uses of `value`. If no such uses exist, return 1.
-static unsigned getAssumedAlignment(Value value) {
-  unsigned align = 1;
-  for (auto &u : value.getUses()) {
-    Operation *owner = u.getOwner();
-    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
-      align = mlir::lcm(align, op.alignment());
-  }
-  return align;
-}
-
-// Helper that returns data layout alignment of a memref associated with a
-// load, store, scatter, or gather op, including additional information from
-// assume_alignment calls on the source of the transfer
-template <class OpAdaptor>
-LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
-                                   OpAdaptor op, unsigned &align) {
-  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
-    return failure();
-  align = std::max(align, getAssumedAlignment(op.base()));
-  return success();
-}
-
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -246,8 +222,7 @@ public:
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
-                                    align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
@@ -276,7 +251,7 @@ public:
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
@@ -310,7 +285,7 @@ public:
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
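The hunks below remove the corresponding *_aligned test cases. Alignment facts can still reach LLVM after this revert: memref.assume_alignment has its own lowering that emits an llvm.intr.assume over a pointer-mask check, roughly as in the following sketch (an assumption about the shape of the lowered IR; value names are hypothetical), and LLVM's analyses propagate that fact to the loads and stores that use the pointer:

  %p    = llvm.ptrtoint %ptr : !llvm.ptr<f32> to i64
  %bits = llvm.and %p, %c31 : i64          // low 5 bits of the address
  %ok   = llvm.icmp "eq" %bits, %c0 : i64  // address is 32-byte aligned
  "llvm.intr.assume"(%ok) : (i1) -> ()     // visible to LLVM alignment inference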
@@ -1293,26 +1293,6 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 
 // -----
 
-func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
-  memref.assume_alignment %A, 32 : memref<?xf32>
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    memref<?xf32>, vector<17xf32>
-  vector.transfer_write %f, %A[%base]
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    vector<17xf32>, memref<?xf32>
-  return %f: vector<17xf32>
-}
-// CHECK: llvm.intr.masked.load
-// CHECK-SAME: {alignment = 32 : i32}
-// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
-// CHECK: llvm.intr.masked.store
-// CHECK-SAME: {alignment = 32 : i32}
-// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
-
-// -----
-
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = arith.constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1485,22 +1465,6 @@ func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : ind
 
 // -----
 
-func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func @vector_load_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
-// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<8xf32>>
-
-// -----
-
 func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = arith.constant dense<11.0> : vector<4xf32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1527,23 +1491,6 @@ func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : in
 
 // -----
 
-func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %val = arith.constant dense<11.0> : vector<4xf32>
-  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
-  return
-}
-
-// CHECK-LABEL: func @vector_store_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
-// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<4xf32>>
-
-// -----
-
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1621,20 +1568,6 @@ func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vec
 
 // -----
 
-func @gather_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
-  return %1 : vector<3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: return %[[G]] : vector<3xf32>
-
-// -----
-
 func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1673,19 +1606,6 @@ func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: ve
 
 // -----
 
-func @scatter_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
-  return
-}
-
-// CHECK-LABEL: func @scatter_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
-
-// -----
-
 func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
   %0 = arith.constant 3 : index
   vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>