LLVM dialect: introduce llvm.addressof to access globals

This instruction is a local counterpart of llvm.global that takes a symbol
reference to a global and produces an SSA value containing the pointer to it.
Used in combination, these two operations allow one to use globals with other
operations expecting SSA values.  At a cost of IR indirection, we make sure the
functions don't implicitly capture the surrounding SSA values and remain
suitable for parallel processing.

PiperOrigin-RevId: 262908622
This commit is contained in:
Alex Zinenko 2019-08-12 06:10:29 -07:00 committed by A. Unique TensorFlower
parent 252ada4932
commit 2dd38b09c1
7 changed files with 178 additions and 10 deletions

View File

@ -255,6 +255,27 @@ Selection: `select <condition>, <lhs>, <rhs>`.
These operations do not have LLVM IR counterparts but are necessary to map LLVM
IR into MLIR.
#### `llvm.addressof`
Creates an SSA value containing a pointer to a global variable or constant
defined by `llvm.global`. The global value can be defined after its first
referenced. If the global value is a constant, storing into it is not allowed.
Examples:
```mlir {.mlir}
func @foo() {
// Get the address of a global.
%0 = llvm.addressof @const : !llvm<"i32*">
// Use it as a regular pointer.
%1 = llvm.load %0 : !llvm<"i32*">
}
// Define the global.
llvm.global @const(42 : i32) : !llvm.i32
```
#### `llvm.constant`
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all

View File

@ -192,7 +192,7 @@ def FCmpPredicate : I64EnumAttr<
[FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE,
FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD,
FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
]> {
let cppNamespace = "mlir::LLVM";
@ -394,6 +394,32 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to
// work correctly).
def LLVM_AddressOfOp
: LLVM_OneResultOp<"addressof">,
Arguments<(ins SymbolRefAttr:$global_name)> {
let builders = [
OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, "
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
result->addAttribute("global_name", builder->getSymbolRefAttr(name));
result->addAttributes(attrs);
result->addTypes(resType);}]>,
OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, "
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result, global.getType().getPointerTo(), global.sym_name(),
attrs);}]>
];
let extraClassDeclaration = [{
/// Return the llvm.global operation that defined the value referenced here.
GlobalOp getGlobal();
}];
let printer = "printAddressOfOp(p, *this);";
let parser = "return parseAddressOfOp(parser, result);";
let verifier = "return ::verify(*this);";
}
def LLVM_GlobalOp
: LLVM_ZeroResultOp<"global">,
Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,

View File

@ -89,6 +89,9 @@ private:
ModuleOp mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between llvm.global definitions and corresponding globals.
llvm::DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
protected:
// Mappings between original and translated values, used for lookups.
llvm::StringMap<llvm::Function *> functionMapping;

View File

@ -788,6 +788,49 @@ static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
return success();
}
//===----------------------------------------------------------------------===//
// Printer, parser and verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
auto module = getParentOfType<ModuleOp>();
assert(module && "unexpected operation outside of a module");
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
}
static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
*p << op.getOperationName() << " @" << op.global_name();
p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
*p << " : " << op.getResult()->getType();
}
static ParseResult parseAddressOfOp(OpAsmParser *parser,
OperationState *result) {
Attribute symRef;
Type type;
if (parser->parseAttribute(symRef, "global_name", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->addTypeToList(type, result->types))
return failure();
if (!symRef.isa<SymbolRefAttr>())
return parser->emitError(parser->getNameLoc(), "expected symbol reference");
return success();
}
static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
if (!global)
return op.emitOpError("must reference a global defined by 'llvm.global'");
if (global.getType().getPointerTo() != op.getResult()->getType())
return op.emitOpError(
"the type must be a pointer to the type of the referred global");
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//

View File

@ -247,6 +247,18 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
return success();
}
// Emit addressof. We need to look up the global value referenced by the
// operation and store it in the MLIR-to-LLVM value mapping. This does not
// emit any LLVM instruction.
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
LLVM::GlobalOp global = addressOfOp.getGlobal();
// The verifier should not have allowed this.
assert(global && "referencing an undefined global");
valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
return success();
}
return opInst.emitError("unsupported or non-LLVM operation: ")
<< opInst.getName();
}
@ -290,21 +302,23 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
// Create named global variables that correspond to llvm.global definitions.
void ModuleTranslation::convertGlobals() {
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
llvm::Constant *cst;
llvm::Type *type;
// String attributes are treated separately because they cannot appear as
// in-function constants and are thus not supported by getLLVMConstant.
if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
llvm::Constant *cst = llvm::ConstantDataArray::getString(
cst = llvm::ConstantDataArray::getString(
llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(),
llvm::GlobalValue::InternalLinkage, cst,
op.sym_name());
return;
type = cst->getType();
} else {
type = op.getType().getUnderlyingType();
cst = getLLVMConstant(type, op.value(), op.getLoc());
}
llvm::Type *type = op.getType().getUnderlyingType();
new llvm::GlobalVariable(
*llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name());
auto *var = new llvm::GlobalVariable(*llvmModule, type, op.constant(),
llvm::GlobalValue::InternalLinkage,
cst, op.sym_name());
globalsMapping.try_emplace(op, var);
}
}

View File

@ -12,6 +12,17 @@ llvm.global constant @string("foobar") : !llvm<"[6 x i8]">
// CHECK: llvm.global @string_notype("1234567")
llvm.global @string_notype("1234567")
// CHECK-LABEL: references
func @references() {
// CHECK: llvm.addressof @global : !llvm<"i64*">
%0 = llvm.addressof @global : !llvm<"i64*">
// CHECK: llvm.addressof @string : !llvm<"[6 x i8]*">
%1 = llvm.addressof @string : !llvm<"[6 x i8]*">
llvm.return
}
// -----
// expected-error @+1 {{op requires attribute 'sym_name'}}
@ -54,3 +65,36 @@ llvm.global @i64_needs_type(0: i64)
// expected-error @+1 {{expected zero or one type}}
llvm.global @more_than_one_type(0) : !llvm.i64, !llvm.i32
// -----
llvm.global @foo(0: i32) : !llvm.i32
func @bar() {
// expected-error @+2{{expected ':'}}
llvm.addressof @foo
}
// -----
func @foo() {
// The attribute parser will consume the first colon-type, so we put two of
// them to trigger the attribute type mismatch error.
// expected-error @+1 {{expected symbol reference}}
llvm.addressof "foo" : i64 : !llvm<"void ()*">
}
// -----
func @foo() {
// expected-error @+1 {{must reference a global defined by 'llvm.global'}}
llvm.addressof @foo : !llvm<"void ()*">
}
// -----
llvm.global @foo(0: i32) : !llvm.i32
func @bar() {
// expected-error @+1 {{the type must be a pointer to the type of the referred global}}
llvm.addressof @foo : !llvm<"i64*">
}

View File

@ -33,6 +33,23 @@ func @empty() {
llvm.return
}
// CHECK-LABEL: @global_refs
func @global_refs() {
// Check load from globals.
// CHECK: load i32, i32* @i32_global
%0 = llvm.addressof @i32_global : !llvm<"i32*">
%1 = llvm.load %0 : !llvm<"i32*">
// Check the contracted form of load from array constants.
// CHECK: load i8, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @string_const, i64 0, i64 0)
%2 = llvm.addressof @string_const : !llvm<"[6 x i8]*">
%c0 = llvm.constant(0 : index) : !llvm.i64
%3 = llvm.getelementptr %2[%c0, %c0] : (!llvm<"[6 x i8]*">, !llvm.i64, !llvm.i64) -> !llvm<"i8*">
%4 = llvm.load %3 : !llvm<"i8*">
llvm.return
}
// CHECK-LABEL: declare void @body(i64)
func @body(!llvm.i64)