[mlir] Make it possible to directly supply constant values to LLVM GEPOp

In LLVM IR, the GEP indices that correspond to structures are required to be
i32 constants. MLIR models constants as just values defined by special
operations, and there is no verification that it is the case for structure
indices in GEP. Furthermore, some common transformations such as control flow
simplification may lead to the operands becoming non-constant. Make it possible
to directly supply constant values to LLVM GEPOp to guarantee they remain
constant until the translation to LLVM IR. This is not yet a requirement and
the verifier is not modified, this will be introduced separately.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116757
This commit is contained in:
Alex Zinenko 2022-01-06 23:29:15 +01:00
parent 348bc76e35
commit cafaa35036
8 changed files with 122 additions and 17 deletions

View File

@ -315,17 +315,39 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
let printer = [{ printAllocaOp(p, *this); }];
}
def LLVM_GEPOp
: LLVM_Op<"getelementptr", [NoSideEffect]>,
LLVM_Builder<
"$res = builder.CreateGEP("
" $base->getType()->getPointerElementType(), $base, $indices);"> {
def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices,
I32ElementsAttr:$structIndices);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
let builders = [LLVM_OneResultOpBuilder];
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
"ArrayRef<int32_t>":$structIndices,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
];
let llvmBuilder = [{
SmallVector<llvm::Value *> indices;
indices.reserve($structIndices.size());
unsigned operandIdx = 0;
for (int32_t structIndex : $structIndices.getValues<int32_t>()) {
if (structIndex == GEPOp::kDynamicIndex)
indices.push_back($indices[operandIdx++]);
else
indices.push_back(builder.getInt32(structIndex));
}
$res = builder.CreateGEP(
$base->getType()->getPointerElementType(), $base, indices);
}];
let assemblyFormat = [{
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
$base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
`:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
}];
let hasFolder = 1;
}

View File

@ -790,8 +790,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
Value gepPtr = rewriter.create<LLVM::GEPOp>(
loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
ArrayRef<Value>{numElements});
auto sizeBytes =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

View File

@ -162,8 +162,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Buffer size in bytes.
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
Value gepPtr = rewriter.create<LLVM::GEPOp>(
loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
ArrayRef<Value>{runningStride});
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
@ -178,8 +178,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
LLVM::LLVMPointerType::get(typeConverter->convertType(type));
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
loc, convertedPtrType, nullPtr,
ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}

View File

@ -497,10 +497,11 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
Type elementType = typeConverter->convertType(type.getElementType());
Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
SmallVector<Value, 4> operands = {addressOf};
SmallVector<Value> operands;
operands.insert(operands.end(), type.getRank() + 1,
createIndexConstant(rewriter, loc, 0));
auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
auto gep =
rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
// We do not expect the memref obtained using `memref.get_global` to be
// ever deallocated. Set the allocated pointer to be known bad value to

View File

@ -356,6 +356,67 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
: getCaseOperandsMutable(index - 1);
}
//===----------------------------------------------------------------------===//
// Code for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Value basePtr, ValueRange operands,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultType, basePtr, operands,
SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex),
attributes);
}
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Value basePtr, ValueRange indices,
ArrayRef<int32_t> structIndices,
ArrayRef<NamedAttribute> attributes) {
result.addTypes(resultType);
result.addAttributes(attributes);
result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
result.addOperands(basePtr);
result.addOperands(indices);
}
static ParseResult
parseGEPIndices(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::OperandType> &indices,
DenseIntElementsAttr &structIndices) {
SmallVector<int32_t> constantIndices;
do {
int32_t constantIndex;
OptionalParseResult parsedInteger =
parser.parseOptionalInteger(constantIndex);
if (parsedInteger.hasValue()) {
if (failed(parsedInteger.getValue()))
return failure();
constantIndices.push_back(constantIndex);
continue;
}
constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
if (failed(parser.parseOperand(indices.emplace_back())))
return failure();
} while (succeeded(parser.parseOptionalComma()));
structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
return success();
}
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
OperandRange indices,
DenseIntElementsAttr structIndices) {
unsigned operandIdx = 0;
llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
[&](int32_t cst) {
if (cst == LLVM::GEPOp::kDynamicIndex)
printer.printOperand(indices[operandIdx++]);
else
printer << cst;
});
}
//===----------------------------------------------------------------------===//
// Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===//

View File

@ -760,7 +760,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
Type type = processType(inst->getType());
if (!type)
return failure();
v = b.create<GEPOp>(loc, type, ops);
v = b.create<GEPOp>(loc, type, ops[0],
llvm::makeArrayRef(ops).drop_front());
return success();
}
}

View File

@ -170,6 +170,16 @@ func @ops(%arg0: i32, %arg1: f32,
llvm.return
}
// CHECK-LABEL: @gep
llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
%ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
// CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
// CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
llvm.return
}
// An larger self-contained function.
// CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> {
llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> {

View File

@ -975,6 +975,16 @@ llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(
llvm.return %10 : !llvm.struct<(f32, i32)>
}
// CHECK-LABEL: @gep
llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
%ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
// CHECK: = getelementptr { i32, { i32, float } }, { i32, { i32, float } }* %{{.*}}, i64 %{{.*}}, i32 1, i32 0
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
// CHECK: = getelementptr { [10 x float] }, { [10 x float] }* %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10xf32>)>>, i64, i64) -> !llvm.ptr<f32>
llvm.return
}
//
// Indirect function calls
//