forked from OSchip/llvm-project
LLVM dialect: introduce llvm.addressof to access globals
This instruction is a local counterpart of llvm.global that takes a symbol reference to a global and produces an SSA value containing the pointer to it. Used in combination, these two operations allow one to use globals with other operations expecting SSA values. At a cost of IR indirection, we make sure the functions don't implicitly capture the surrounding SSA values and remain suitable for parallel processing. PiperOrigin-RevId: 262908622
This commit is contained in:
parent
252ada4932
commit
2dd38b09c1
|
@ -255,6 +255,27 @@ Selection: `select <condition>, <lhs>, <rhs>`.
|
|||
These operations do not have LLVM IR counterparts but are necessary to map LLVM
|
||||
IR into MLIR.
|
||||
|
||||
#### `llvm.addressof`
|
||||
|
||||
Creates an SSA value containing a pointer to a global variable or constant
|
||||
defined by `llvm.global`. The global value can be defined after its first
|
||||
referenced. If the global value is a constant, storing into it is not allowed.
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
func @foo() {
|
||||
// Get the address of a global.
|
||||
%0 = llvm.addressof @const : !llvm<"i32*">
|
||||
|
||||
// Use it as a regular pointer.
|
||||
%1 = llvm.load %0 : !llvm<"i32*">
|
||||
}
|
||||
|
||||
// Define the global.
|
||||
llvm.global @const(42 : i32) : !llvm.i32
|
||||
```
|
||||
|
||||
#### `llvm.constant`
|
||||
|
||||
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all
|
||||
|
|
|
@ -192,7 +192,7 @@ def FCmpPredicate : I64EnumAttr<
|
|||
[FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE,
|
||||
FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD,
|
||||
FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
|
||||
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
|
||||
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
|
||||
]> {
|
||||
let cppNamespace = "mlir::LLVM";
|
||||
|
||||
|
@ -394,6 +394,32 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
|
|||
|
||||
// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to
|
||||
// work correctly).
|
||||
def LLVM_AddressOfOp
|
||||
: LLVM_OneResultOp<"addressof">,
|
||||
Arguments<(ins SymbolRefAttr:$global_name)> {
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, "
|
||||
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
|
||||
result->addAttribute("global_name", builder->getSymbolRefAttr(name));
|
||||
result->addAttributes(attrs);
|
||||
result->addTypes(resType);}]>,
|
||||
|
||||
OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, "
|
||||
"ArrayRef<NamedAttribute> attrs = {}", [{
|
||||
build(builder, result, global.getType().getPointerTo(), global.sym_name(),
|
||||
attrs);}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the llvm.global operation that defined the value referenced here.
|
||||
GlobalOp getGlobal();
|
||||
}];
|
||||
|
||||
let printer = "printAddressOfOp(p, *this);";
|
||||
let parser = "return parseAddressOfOp(parser, result);";
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
||||
def LLVM_GlobalOp
|
||||
: LLVM_ZeroResultOp<"global">,
|
||||
Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,
|
||||
|
|
|
@ -89,6 +89,9 @@ private:
|
|||
ModuleOp mlirModule;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
|
||||
// Mappings between llvm.global definitions and corresponding globals.
|
||||
llvm::DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
|
||||
|
||||
protected:
|
||||
// Mappings between original and translated values, used for lookups.
|
||||
llvm::StringMap<llvm::Function *> functionMapping;
|
||||
|
|
|
@ -788,6 +788,49 @@ static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printer, parser and verifier for LLVM::AddressOfOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp AddressOfOp::getGlobal() {
|
||||
auto module = getParentOfType<ModuleOp>();
|
||||
assert(module && "unexpected operation outside of a module");
|
||||
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
|
||||
}
|
||||
|
||||
static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
|
||||
*p << op.getOperationName() << " @" << op.global_name();
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
|
||||
*p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseAddressOfOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
Attribute symRef;
|
||||
Type type;
|
||||
if (parser->parseAttribute(symRef, "global_name", result->attributes) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type) ||
|
||||
parser->addTypeToList(type, result->types))
|
||||
return failure();
|
||||
|
||||
if (!symRef.isa<SymbolRefAttr>())
|
||||
return parser->emitError(parser->getNameLoc(), "expected symbol reference");
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(AddressOfOp op) {
|
||||
auto global = op.getGlobal();
|
||||
if (!global)
|
||||
return op.emitOpError("must reference a global defined by 'llvm.global'");
|
||||
|
||||
if (global.getType().getPointerTo() != op.getResult()->getType())
|
||||
return op.emitOpError(
|
||||
"the type must be a pointer to the type of the referred global");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::ConstantOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -247,6 +247,18 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Emit addressof. We need to look up the global value referenced by the
|
||||
// operation and store it in the MLIR-to-LLVM value mapping. This does not
|
||||
// emit any LLVM instruction.
|
||||
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
|
||||
LLVM::GlobalOp global = addressOfOp.getGlobal();
|
||||
// The verifier should not have allowed this.
|
||||
assert(global && "referencing an undefined global");
|
||||
|
||||
valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
|
||||
return success();
|
||||
}
|
||||
|
||||
return opInst.emitError("unsupported or non-LLVM operation: ")
|
||||
<< opInst.getName();
|
||||
}
|
||||
|
@ -290,21 +302,23 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
|
|||
// Create named global variables that correspond to llvm.global definitions.
|
||||
void ModuleTranslation::convertGlobals() {
|
||||
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
|
||||
llvm::Constant *cst;
|
||||
llvm::Type *type;
|
||||
// String attributes are treated separately because they cannot appear as
|
||||
// in-function constants and are thus not supported by getLLVMConstant.
|
||||
if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
|
||||
llvm::Constant *cst = llvm::ConstantDataArray::getString(
|
||||
cst = llvm::ConstantDataArray::getString(
|
||||
llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
|
||||
new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(),
|
||||
llvm::GlobalValue::InternalLinkage, cst,
|
||||
op.sym_name());
|
||||
return;
|
||||
type = cst->getType();
|
||||
} else {
|
||||
type = op.getType().getUnderlyingType();
|
||||
cst = getLLVMConstant(type, op.value(), op.getLoc());
|
||||
}
|
||||
|
||||
llvm::Type *type = op.getType().getUnderlyingType();
|
||||
new llvm::GlobalVariable(
|
||||
*llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
|
||||
getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name());
|
||||
auto *var = new llvm::GlobalVariable(*llvmModule, type, op.constant(),
|
||||
llvm::GlobalValue::InternalLinkage,
|
||||
cst, op.sym_name());
|
||||
globalsMapping.try_emplace(op, var);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,17 @@ llvm.global constant @string("foobar") : !llvm<"[6 x i8]">
|
|||
// CHECK: llvm.global @string_notype("1234567")
|
||||
llvm.global @string_notype("1234567")
|
||||
|
||||
// CHECK-LABEL: references
|
||||
func @references() {
|
||||
// CHECK: llvm.addressof @global : !llvm<"i64*">
|
||||
%0 = llvm.addressof @global : !llvm<"i64*">
|
||||
|
||||
// CHECK: llvm.addressof @string : !llvm<"[6 x i8]*">
|
||||
%1 = llvm.addressof @string : !llvm<"[6 x i8]*">
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{op requires attribute 'sym_name'}}
|
||||
|
@ -54,3 +65,36 @@ llvm.global @i64_needs_type(0: i64)
|
|||
// expected-error @+1 {{expected zero or one type}}
|
||||
llvm.global @more_than_one_type(0) : !llvm.i64, !llvm.i32
|
||||
|
||||
// -----
|
||||
|
||||
llvm.global @foo(0: i32) : !llvm.i32
|
||||
|
||||
func @bar() {
|
||||
// expected-error @+2{{expected ':'}}
|
||||
llvm.addressof @foo
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @foo() {
|
||||
// The attribute parser will consume the first colon-type, so we put two of
|
||||
// them to trigger the attribute type mismatch error.
|
||||
// expected-error @+1 {{expected symbol reference}}
|
||||
llvm.addressof "foo" : i64 : !llvm<"void ()*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @foo() {
|
||||
// expected-error @+1 {{must reference a global defined by 'llvm.global'}}
|
||||
llvm.addressof @foo : !llvm<"void ()*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.global @foo(0: i32) : !llvm.i32
|
||||
|
||||
func @bar() {
|
||||
// expected-error @+1 {{the type must be a pointer to the type of the referred global}}
|
||||
llvm.addressof @foo : !llvm<"i64*">
|
||||
}
|
||||
|
|
|
@ -33,6 +33,23 @@ func @empty() {
|
|||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @global_refs
|
||||
func @global_refs() {
|
||||
// Check load from globals.
|
||||
// CHECK: load i32, i32* @i32_global
|
||||
%0 = llvm.addressof @i32_global : !llvm<"i32*">
|
||||
%1 = llvm.load %0 : !llvm<"i32*">
|
||||
|
||||
// Check the contracted form of load from array constants.
|
||||
// CHECK: load i8, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @string_const, i64 0, i64 0)
|
||||
%2 = llvm.addressof @string_const : !llvm<"[6 x i8]*">
|
||||
%c0 = llvm.constant(0 : index) : !llvm.i64
|
||||
%3 = llvm.getelementptr %2[%c0, %c0] : (!llvm<"[6 x i8]*">, !llvm.i64, !llvm.i64) -> !llvm<"i8*">
|
||||
%4 = llvm.load %3 : !llvm<"i8*">
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: declare void @body(i64)
|
||||
func @body(!llvm.i64)
|
||||
|
||||
|
|
Loading…
Reference in New Issue