[MLIR][LLVM] Add llvm.mlir.global_ctors/dtors and translation support

Add llvm.mlir.global_ctors and global_dtors ops and their translation
support to LLVM global_ctors/global_dtors global variables.

Differential Revision: https://reviews.llvm.org/D112524
This commit is contained in:
Uday Bondhugula 2021-10-20 15:14:54 +05:30
parent 349295fcf3
commit 57b9b29649
7 changed files with 215 additions and 4 deletions

View File

@ -1153,6 +1153,66 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
let verifier = "return ::verify(*this);";
}
def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins FlatSymbolRefArrayAttr
: $ctors, I32ArrayAttr
: $priorities);
let summary = "LLVM dialect global_ctors.";
let description = [{
Specifies a list of constructor functions and priorities. The functions
referenced by this array will be called in ascending order of priority (i.e.
lowest first) when the module is loaded. The order of functions with the
same priority is not defined. This operation is translated to LLVM's
global_ctors global variable. The initializer functions are run at load
time. The `data` field present in LLVM's global_ctors variable is not
modeled here.
Examples:
```mlir
llvm.mlir.global_ctors {@ctor}
llvm.func @ctor() {
...
llvm.return
}
```
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
}
def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
FlatSymbolRefArrayAttr:$dtors,
I32ArrayAttr:$priorities
);
let summary = "LLVM dialect global_dtors.";
let description = [{
Specifies a list of destructor functions and priorities. The functions
referenced by this array will be called in descending order of priority (i.e.
highest first) when the module is unloaded. The order of functions with the
same priority is not defined. This operation is translated to LLVM's
global_dtors global variable. The `data` field present in LLVM's
global_dtors variable is not modeled here.
Examples:
```mlir
llvm.func @dtor() {
llvm.return
}
llvm.mlir.global_dtors {@dtor}
```
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
}
def LLVM_LLVMFuncOp : LLVM_Op<"func",
[AutomaticAllocationScope, IsolatedFromAbove, FunctionLike, Symbol]> {
let summary = "LLVM dialect function.";

View File

@ -1641,6 +1641,11 @@ def SymbolRefArrayAttr :
let constBuilderCall = ?;
}
def FlatSymbolRefArrayAttr :
TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
let constBuilderCall = ?;
}
//===----------------------------------------------------------------------===//
// Derive attribute kinds

View File

@ -70,6 +70,22 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
}
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
/// fully defined llvm.func.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
Operation *op,
SymbolTableCollection &symbolTable) {
StringRef name = symbol.getValue();
auto func =
symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
if (!func)
return op->emitOpError("'")
<< name << "' does not reference a valid LLVM function";
if (func.isExternal())
return op->emitOpError("'") << name << "' does not have a definition";
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CmpOp.
//===----------------------------------------------------------------------===//
@ -1624,6 +1640,48 @@ static LogicalResult verify(GlobalOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// LLVM::GlobalCtorsOp
//===----------------------------------------------------------------------===//
LogicalResult
GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
for (Attribute ctor : ctors()) {
if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
symbolTable)))
return failure();
}
return success();
}
static LogicalResult verify(GlobalCtorsOp op) {
if (op.ctors().size() != op.priorities().size())
return op.emitError(
"mismatch between the number of ctors and the number of priorities");
return success();
}
//===----------------------------------------------------------------------===//
// LLVM::GlobalDtorsOp
//===----------------------------------------------------------------------===//
LogicalResult
GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
for (Attribute dtor : dtors()) {
if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
symbolTable)))
return failure();
}
return success();
}
static LogicalResult verify(GlobalDtorsOp op) {
if (op.dtors().size() != op.priorities().size())
return op.emitError(
"mismatch between the number of dtors and the number of priorities");
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ShuffleVectorOp.
//===----------------------------------------------------------------------===//
@ -2353,7 +2411,7 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
static constexpr const FastmathFlags FastmathFlagsList[] = {
static constexpr const FastmathFlags fastmathFlagsList[] = {
// clang-format off
FastmathFlags::nnan,
FastmathFlags::ninf,
@ -2368,7 +2426,7 @@ static constexpr const FastmathFlags FastmathFlagsList[] = {
void FMFAttr::print(DialectAsmPrinter &printer) const {
printer << "fastmath<";
auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
return bitEnumContains(this->getFlags(), flag);
});
llvm::interleaveComma(flags, printer,

View File

@ -42,6 +42,7 @@
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace mlir;
using namespace mlir::LLVM;
@ -556,7 +557,7 @@ static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
}
/// Create named global variables that correspond to llvm.mlir.global
/// definitions.
/// definitions. Convert llvm.global_ctors and global_dtors ops.
LogicalResult ModuleTranslation::convertGlobals() {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
llvm::Type *type = convertType(op.getType());
@ -625,6 +626,26 @@ LogicalResult ModuleTranslation::convertGlobals() {
}
}
// Convert llvm.mlir.global_ctors and dtors.
for (Operation &op : getModuleBody(mlirModule)) {
auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
if (!ctorOp && !dtorOp)
continue;
auto range = ctorOp ? llvm::zip(ctorOp.ctors(), ctorOp.priorities())
: llvm::zip(dtorOp.dtors(), dtorOp.priorities());
auto appendGlobalFn =
ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
for (auto symbolAndPriority : range) {
llvm::Function *f = lookupFunction(
std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
appendGlobalFn(
*llvmModule.get(), f,
std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
/*Data=*/nullptr);
}
}
return success();
}
@ -1028,7 +1049,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
// Convert other top-level operations if possible.
llvm::IRBuilder<> llvmBuilder(llvmContext);
for (Operation &o : getModuleBody(module).getOperations()) {
if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::MetadataOp>(&o) &&
if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
LLVM::GlobalDtorsOp, LLVM::MetadataOp>(&o) &&
!o.hasTrait<OpTrait::IsTerminator>() &&
failed(translator.convertOperation(o, llvmBuilder))) {
return nullptr;

View File

@ -209,3 +209,21 @@ func @mismatch_addr_space() {
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm.ptr<i64, 4>
}
// -----
llvm.func @ctor() {
llvm.return
}
// CHECK: llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]}
llvm.mlir.global_ctors { ctors = [@ctor], priorities = [0 : i32]}
// -----
llvm.func @dtor() {
llvm.return
}
// CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]}

View File

@ -5,6 +5,36 @@ llvm.mlir.global private @invalid_global_alignment(42 : i64) {alignment = 63} :
// -----
llvm.func @ctor() {
llvm.return
}
// expected-error@+1{{mismatch between the number of ctors and the number of priorities}}
llvm.mlir.global_ctors {ctors = [@ctor], priorities = []}
// -----
llvm.func @dtor() {
llvm.return
}
// expected-error@+1{{mismatch between the number of dtors and the number of priorities}}
llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32, 32767 : i32]}
// -----
// expected-error@+1{{'ctor' does not reference a valid LLVM function}}
llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]}
// -----
llvm.func @dtor()
// expected-error@+1{{'dtor' does not have a definition}}
llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
// -----
// expected-error@+1{{expected llvm.noalias argument attribute to be a unit attribute}}
func @invalid_noalias(%arg0: i32 {llvm.noalias = 3}) {
"llvm.return"() : () -> ()

View File

@ -1384,6 +1384,24 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
// -----
// CHECK: @llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }]
llvm.mlir.global_ctors { ctors = [@foo], priorities = [0 : i32]}
llvm.func @foo() {
llvm.return
}
// -----
// CHECK: @llvm.global_dtors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }]
llvm.mlir.global_dtors { dtors = [@foo], priorities = [0 : i32]}
llvm.func @foo() {
llvm.return
}
// -----
// Check that branch weight attributes are exported properly as metadata.
llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// CHECK: !prof ![[NODE:[0-9]+]]