[mlir] Require struct indices in LLVM::GEPOp to be constant

Recent commits made it possible to supply indices of LLVM dialect GEP
operations directly as constant attributes, which ensures they remain constant
until translation to LLVM IR happens. Make this mandatory for indices into
LLVM struct types to match LLVM IR requirements; otherwise, the translation
would assert when constructing such IR.

For better compatibility with the MLIR-style operation construction interface,
allow GEP operations to be constructed programmatically using Values that are
defined by known constant operations as struct indices.

Depends On D116758

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116759
This commit is contained in:
Alex Zinenko 2022-01-06 23:30:15 +01:00
parent 43ff4a6d55
commit f50cfc44d6
7 changed files with 146 additions and 9 deletions

View File

@ -350,6 +350,9 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
}];
let hasFolder = 1;
let verifier = [{
return ::verify(*this);
}];
}
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {

View File

@ -360,6 +360,58 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
// Code for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
/// Collects the GEP index positions at which an LLVMStructType may be
/// encountered while descending into `type`. `currentIndex` is the position
/// assigned to `type` itself; the element types of structs and of vector/array
/// containers are visited at `currentIndex + 1`. Every member of a struct is
/// explored, not only the member a particular GEP would select. When
/// `structSizes` is non-null, the number of body elements of each recorded
/// struct is appended to it in lockstep with `indices`.
///
/// `visited` is shared by the whole traversal: a type that has already been
/// seen is skipped. This bounds the recursion, and it also means that each
/// distinct type is processed at most once, at the depth where it was first
/// reached.
static void recordStructIndices(Type type, unsigned currentIndex,
                                SmallVectorImpl<unsigned> &indices,
                                SmallVectorImpl<unsigned> *structSizes,
                                SmallPtrSet<Type, 4> &visited) {
  // `insert` reports whether the type was newly added; stop on repeats.
  if (!visited.insert(type).second)
    return;

  const unsigned nestedIndex = currentIndex + 1;

  if (auto structType = type.dyn_cast<LLVMStructType>()) {
    auto memberTypes = structType.getBody();
    indices.push_back(currentIndex);
    if (structSizes)
      structSizes->push_back(memberTypes.size());
    for (Type memberType : memberTypes)
      recordStructIndices(memberType, nestedIndex, indices, structSizes,
                          visited);
    return;
  }

  // Containers contribute no position of their own; only look through them.
  llvm::TypeSwitch<Type>(type)
      .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
            LLVMArrayType>([&](auto containerType) {
        recordStructIndices(containerType.getElementType(), nestedIndex,
                            indices, structSizes, visited);
      });
}
/// Fills `indices` with the GEP index positions that may address an
/// LLVMStructType nested in the pointee of `baseGEPType`. `baseGEPType` must be
/// an LLVMPointerType or a (builtin, LLVM fixed or LLVM scalable) vector of
/// such pointers; any other type trips the `cast` below. The pointee itself is
/// assigned position 1. If `structSizes` is supplied, it receives the element
/// count of each recorded struct so that callers can check index bounds.
static void
findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
                       SmallVectorImpl<unsigned> *structSizes = nullptr) {
  // Peel off the vector wrapper, if any, to get at the pointer type.
  Type pointerType = baseGEPType;
  if (pointerType.isa<VectorType>())
    pointerType = pointerType.cast<VectorType>().getElementType();
  if (pointerType.isa<LLVMScalableVectorType>())
    pointerType = pointerType.cast<LLVMScalableVectorType>().getElementType();
  if (pointerType.isa<LLVMFixedVectorType>())
    pointerType = pointerType.cast<LLVMFixedVectorType>().getElementType();

  SmallPtrSet<Type, 4> seenTypes;
  recordStructIndices(pointerType.cast<LLVMPointerType>().getElementType(),
                      /*currentIndex=*/1, indices, structSizes, seenTypes);
}
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Value basePtr, ValueRange operands,
ArrayRef<NamedAttribute> attributes) {
@ -372,11 +424,58 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Value basePtr, ValueRange indices,
ArrayRef<int32_t> structIndices,
ArrayRef<NamedAttribute> attributes) {
SmallVector<Value> remainingIndices;
SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
structIndices.end());
SmallVector<unsigned> structRelatedPositions;
findKnownStructIndices(basePtr.getType(), structRelatedPositions);
SmallVector<unsigned> operandsToErase;
for (unsigned pos : structRelatedPositions) {
// GEP may not be indexing as deep as some structs are located.
if (pos >= structIndices.size())
continue;
// If the index is already static, it's fine.
if (structIndices[pos] != kDynamicIndex)
continue;
// Find the corresponding operand.
unsigned operandPos =
std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
kDynamicIndex);
// Extract the constant value from the operand and put it into the attribute
// instead.
APInt staticIndexValue;
bool matched =
matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
(void)matched;
assert(matched && "index into a struct must be a constant");
assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
"struct index underflows 32-bit integer");
assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
"struct index overflows 32-bit integer");
auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
updatedStructIndices[pos] = staticIndex;
operandsToErase.push_back(operandPos);
}
for (unsigned i = 0, e = indices.size(); i < e; ++i) {
if (llvm::find(operandsToErase, i) == operandsToErase.end())
remainingIndices.push_back(indices[i]);
}
assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
updatedStructIndices, kDynamicIndex)) &&
"exected as many index operands as dynamic index attr elements");
result.addTypes(resultType);
result.addAttributes(attributes);
result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
result.addAttribute("structIndices",
builder.getI32TensorAttr(updatedStructIndices));
result.addOperands(basePtr);
result.addOperands(indices);
result.addOperands(remainingIndices);
}
static ParseResult
@ -417,6 +516,27 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
});
}
/// Verifies that every GEP index that steps into an LLVM struct is a static
/// (attribute) index lying within the bounds of that struct.
///
/// The check walks the pointee type along the indices the operation actually
/// supplies: index 0 offsets the base pointer itself, index 1 applies to the
/// pointee, and each further index applies to the type selected by the
/// previous one. Only the struct member chosen by a static index is followed,
/// so the bounds check always uses the size of the struct being indexed rather
/// than the size of an unrelated struct that happens to sit at the same depth
/// elsewhere in the type.
///
/// Returns failure, after emitting an op error, if a struct is indexed
/// dynamically or with an out-of-bounds static index; success otherwise.
LogicalResult verify(LLVM::GEPOp gepOp) {
  // The base is a pointer or a vector of pointers; look through the vector.
  Type baseType = gepOp.getBase().getType();
  if (auto vectorType = baseType.dyn_cast<VectorType>())
    baseType = vectorType.getElementType();
  if (auto scalableVectorType = baseType.dyn_cast<LLVMScalableVectorType>())
    baseType = scalableVectorType.getElementType();
  if (auto fixedVectorType = baseType.dyn_cast<LLVMFixedVectorType>())
    baseType = fixedVectorType.getElementType();
  Type currentType = baseType.cast<LLVMPointerType>().getElementType();

  unsigned numIndices = gepOp.getStructIndices().getNumElements();
  for (unsigned index = 1; index < numIndices; ++index) {
    if (auto structType = currentType.dyn_cast<LLVMStructType>()) {
      int32_t staticIndex =
          gepOp.getStructIndices().getValues<int32_t>()[index];
      if (staticIndex == LLVM::GEPOp::kDynamicIndex)
        return gepOp.emitOpError() << "expected index " << index
                                   << " indexing a struct to be constant";
      auto memberTypes = structType.getBody();
      if (staticIndex < 0 ||
          static_cast<size_t>(staticIndex) >= memberTypes.size())
        return gepOp.emitOpError()
               << "index " << index << " indexing a struct is out of bounds";
      // Descend only into the member this GEP selects.
      currentType = memberTypes[staticIndex];
      continue;
    }

    // Arrays and vectors may be indexed dynamically; step to their element
    // type. Any other type cannot contain a struct reachable by a GEP index.
    Type elementType =
        llvm::TypeSwitch<Type, Type>(currentType)
            .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
                  LLVMArrayType>(
                [](auto containerType) -> Type {
                  return containerType.getElementType();
                })
            .Default([](Type) -> Type { return Type(); });
    if (!elementType)
      break;
    currentType = elementType;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
//===----------------------------------------------------------------------===//

View File

@ -501,8 +501,7 @@ func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, i64)>>
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], 3]
// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : i64

View File

@ -547,12 +547,11 @@ func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
// CHECK: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
// CHECK-SAME: : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
// CHECK: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
// CHECK-SAME: %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
// CHECK-SAME: i64)>>, i64, i32) -> !llvm.ptr<i64>
// CHECK-SAME: %[[C0_]], 2] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
// CHECK-SAME: i64)>>, i64) -> !llvm.ptr<i64>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %{{.*}} : i64

View File

@ -10,7 +10,7 @@ spv.func @access_chain() "None" {
%0 = spv.Constant 1: i32
%1 = spv.Variable : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32, i32) -> !llvm.ptr<f32>
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32) -> !llvm.ptr<f32>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>, i32, i32
spv.Return
}

View File

@ -1234,3 +1234,19 @@ func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
nvvm.cp.async.shared.global %arg0, %arg1, 32
return
}
// -----
// Negative test: index 1 steps into the struct pointee but is supplied as an
// SSA value (%arg1) instead of a constant, so verification must fail.
func @gep_struct_variable(%arg0: !llvm.ptr<struct<(i32)>>, %arg1: i32, %arg2: i32) {
// expected-error @below {{op expected index 1 indexing a struct to be constant}}
llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr<struct<(i32)>>, i32, i32) -> !llvm.ptr<i32>
return
}
// -----
// Negative test: index 1 (value 1) selects the inner struct<(i32, f32)>, and
// index 2 (value 3) then exceeds that struct's two elements.
func @gep_out_of_bounds(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64) {
// expected-error @below {{index 2 indexing a struct is out of bounds}}
llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
return
}

View File

@ -1444,7 +1444,7 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
%z32 = llvm.mlir.constant(0 : i32) : i32
%0 = llvm.mlir.undef : !llvm.struct<(i32, !llvm.ptr<i32>)>
%1 = llvm.mlir.addressof @take_self_address : !llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>
%2 = llvm.getelementptr %1[%z32, %z32] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32, i32) -> !llvm.ptr<i32>
%2 = llvm.getelementptr %1[%z32, 0] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32) -> !llvm.ptr<i32>
%3 = llvm.insertvalue %z32, %0[0 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
%4 = llvm.insertvalue %2, %3[1 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
llvm.return %4 : !llvm.struct<(i32, !llvm.ptr<i32>)>