Remove LLVM dependency on mlir::Module and instead check Traits.

PiperOrigin-RevId: 285724678
This commit is contained in:
Tres Popp 2019-12-16 01:35:03 -08:00 committed by A. Unique TensorFlower
parent 97af932272
commit 44fc7d72b3
8 changed files with 61 additions and 34 deletions

View File

@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
StringRef value, LLVM::Linkage linkage,
LLVM::LLVMDialect *llvmDialect);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
bool satisfiesLLVMModule(Operation *op);
} // end namespace LLVM
} // end namespace mlir

View File

@ -23,6 +23,7 @@
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Value.h"
@ -50,7 +51,9 @@ class LLVMFuncOp;
class ModuleTranslation {
public:
template <typename T = ModuleTranslation>
static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
if (!satisfiesLLVMModule(m))
return nullptr;
if (failed(checkSupportedModuleOps(m)))
return nullptr;
auto llvmModule = prepareLLVMModule(m);
@ -66,23 +69,30 @@ public:
return std::move(translator.llvmModule);
}
/// A helper method to get the single Block in an operation honoring LLVM's
/// module requirements.
static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
protected:
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
// LLVMContext, the LLVM IR module will be created in that context.
explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
explicit ModuleTranslation(Operation *module) : mlirModule(module) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
virtual ~ModuleTranslation() {}
virtual LogicalResult convertOperation(Operation &op,
llvm::IRBuilder<> &builder);
static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
template <typename Range>
SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
private:
/// Check whether the module contains only supported ops directly in its body.
static LogicalResult checkSupportedModuleOps(ModuleOp m);
static LogicalResult checkSupportedModuleOps(Operation *m);
LogicalResult convertFunctions();
void convertGlobals();
@ -94,7 +104,7 @@ private:
Location loc);
// Original and translated module.
ModuleOp mlirModule;
Operation *mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between llvm.mlir.global definitions and corresponding globals.

View File

@ -30,14 +30,14 @@ class Module;
} // namespace llvm
namespace mlir {
class ModuleOp;
class Operation;
/// Convert the given MLIR module into NVVM IR. This conversion requires the
/// registration of the LLVM IR dialect and will extract the LLVM context
/// from the registered LLVM IR dialect. In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// Convert the given LLVM-module-like operation into NVVM IR. This conversion
/// requires the registration of the LLVM IR dialect and will extract the LLVM
/// context from the registered LLVM IR dialect. In case of error, report it to
/// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
} // namespace mlir

View File

@ -31,14 +31,14 @@ class Module;
} // namespace llvm
namespace mlir {
class ModuleOp;
class Operation;
/// Convert the given MLIR module into ROCDL IR. This conversion requires the
/// registration of the LLVM IR dialect and will extract the LLVM context
/// from the registered LLVM IR dialect. In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
/// requires the registration of the LLVM IR dialect and will extract the LLVM
/// context from the registered LLVM IR dialect. In case of error, report it to
/// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m);
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
} // namespace mlir

View File

@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
auto module = getParentOfType<ModuleOp>();
Operation *module = getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
return dyn_cast_or_null<LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
}
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) {
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
return op.emitOpError(
"expects type to be a valid element type for an LLVM pointer");
if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
if (op.getParentOp() &&
!(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
return op.emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
ArrayRef<Value *>({cst0, cst0}));
}
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

View File

@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
explicit ModuleTranslation(ModuleOp module)
explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {}
@ -73,7 +73,7 @@ protected:
};
} // namespace
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
ModuleTranslation translation(m);
auto llvmModule =
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
for (auto func :
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!gpu::GPUDialect::isKernel(func))
continue;

View File

@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder,
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
explicit ModuleTranslation(ModuleOp module)
explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {}
@ -84,7 +84,7 @@ protected:
};
} // namespace
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
ModuleTranslation translation(m);
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
// foreach GPU kernel
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
for (auto func :
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;

View File

@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
// Create named global variables that correspond to llvm.mlir.global
// definitions.
void ModuleTranslation::convertGlobals() {
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
llvm::Type *type = op.getType().getUnderlyingType();
llvm::Constant *cst = llvm::UndefValue::get(type);
if (op.getValueOrNull()) {
@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
return success();
}
LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
for (Operation &o : m.getBody()->getOperations())
LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
for (Operation &o : getModuleBody(m).getOperations())
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
!isa<ModuleTerminatorOp>(&o))
!o.isKnownTerminator())
return o.emitOpError("unsupported module-level operation");
return success();
}
@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
LogicalResult ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(),
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
}
// Convert functions.
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
// Ignore external functions.
if (function.isExternal())
continue;
@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() {
return success();
}
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
std::unique_ptr<llvm::Module>
ModuleTranslation::prepareLLVMModule(Operation *m) {
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered");
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());