Remove LLVM dependency on mlir::Module and instead check Traits.

PiperOrigin-RevId: 285724678
This commit is contained in:
Tres Popp 2019-12-16 01:35:03 -08:00 committed by A. Unique TensorFlower
parent 97af932272
commit 44fc7d72b3
8 changed files with 61 additions and 34 deletions

View File

@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
StringRef value, LLVM::Linkage linkage, StringRef value, LLVM::Linkage linkage,
LLVM::LLVMDialect *llvmDialect); LLVM::LLVMDialect *llvmDialect);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
bool satisfiesLLVMModule(Operation *op);
} // end namespace LLVM } // end namespace LLVM
} // end namespace mlir } // end namespace mlir

View File

@ -23,6 +23,7 @@
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H #ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
@ -50,7 +51,9 @@ class LLVMFuncOp;
class ModuleTranslation { class ModuleTranslation {
public: public:
template <typename T = ModuleTranslation> template <typename T = ModuleTranslation>
static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) { static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
if (!satisfiesLLVMModule(m))
return nullptr;
if (failed(checkSupportedModuleOps(m))) if (failed(checkSupportedModuleOps(m)))
return nullptr; return nullptr;
auto llvmModule = prepareLLVMModule(m); auto llvmModule = prepareLLVMModule(m);
@ -66,23 +69,30 @@ public:
return std::move(translator.llvmModule); return std::move(translator.llvmModule);
} }
/// A helper method to get the single Block in an operation honoring LLVM's
/// module requirements.
static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
protected: protected:
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an // LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
// LLVMContext, the LLVM IR module will be created in that context. // LLVMContext, the LLVM IR module will be created in that context.
explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {} explicit ModuleTranslation(Operation *module) : mlirModule(module) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
virtual ~ModuleTranslation() {} virtual ~ModuleTranslation() {}
virtual LogicalResult convertOperation(Operation &op, virtual LogicalResult convertOperation(Operation &op,
llvm::IRBuilder<> &builder); llvm::IRBuilder<> &builder);
static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m); static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
template <typename Range> template <typename Range>
SmallVector<llvm::Value *, 8> lookupValues(Range &&values); SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
private: private:
/// Check whether the module contains only supported ops directly in its body. /// Check whether the module contains only supported ops directly in its body.
static LogicalResult checkSupportedModuleOps(ModuleOp m); static LogicalResult checkSupportedModuleOps(Operation *m);
LogicalResult convertFunctions(); LogicalResult convertFunctions();
void convertGlobals(); void convertGlobals();
@ -94,7 +104,7 @@ private:
Location loc); Location loc);
// Original and translated module. // Original and translated module.
ModuleOp mlirModule; Operation *mlirModule;
std::unique_ptr<llvm::Module> llvmModule; std::unique_ptr<llvm::Module> llvmModule;
// Mappings between llvm.mlir.global definitions and corresponding globals. // Mappings between llvm.mlir.global definitions and corresponding globals.

View File

@ -30,14 +30,14 @@ class Module;
} // namespace llvm } // namespace llvm
namespace mlir { namespace mlir {
class ModuleOp; class Operation;
/// Convert the given MLIR module into NVVM IR. This conversion requires the /// Convert the given LLVM-module-like operation into NVVM IR. This conversion
/// registration of the LLVM IR dialect and will extract the LLVM context /// requires the registration of the LLVM IR dialect and will extract the LLVM
/// from the registered LLVM IR dialect. In case of error, report it /// context from the registered LLVM IR dialect. In case of error, report it to
/// to the error handler registered with the MLIR context, if any (obtained from /// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`. /// the MLIR module), and return `nullptr`.
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m); std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
} // namespace mlir } // namespace mlir

View File

@ -31,14 +31,14 @@ class Module;
} // namespace llvm } // namespace llvm
namespace mlir { namespace mlir {
class ModuleOp; class Operation;
/// Convert the given MLIR module into ROCDL IR. This conversion requires the /// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
/// registration of the LLVM IR dialect and will extract the LLVM context /// requires the registration of the LLVM IR dialect and will extract the LLVM
/// from the registered LLVM IR dialect. In case of error, report it /// context from the registered LLVM IR dialect. In case of error, report it to
/// to the error handler registered with the MLIR context, if any (obtained from /// the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`. /// the MLIR module), and return `nullptr`.
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m); std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
} // namespace mlir } // namespace mlir

View File

@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() { GlobalOp AddressOfOp::getGlobal() {
auto module = getParentOfType<ModuleOp>(); Operation *module = getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module"); assert(module && "unexpected operation outside of a module");
return module.lookupSymbol<LLVM::GlobalOp>(global_name()); return dyn_cast_or_null<LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
} }
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) { static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) {
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
return op.emitOpError( return op.emitOpError(
"expects type to be a valid element type for an LLVM pointer"); "expects type to be a valid element type for an LLVM pointer");
if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp())) if (op.getParentOp() &&
!(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
return op.emitOpError("must appear at the module level"); return op.emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
ArrayRef<Value *>({cst0, cst0})); ArrayRef<Value *>({cst0, cst0}));
} }
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

View File

@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
class ModuleTranslation : public LLVM::ModuleTranslation { class ModuleTranslation : public LLVM::ModuleTranslation {
public: public:
explicit ModuleTranslation(ModuleOp module) explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {} : LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {} ~ModuleTranslation() override {}
@ -73,7 +73,7 @@ protected:
}; };
} // namespace } // namespace
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) { std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
ModuleTranslation translation(m); ModuleTranslation translation(m);
auto llvmModule = auto llvmModule =
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m); LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel. // function as a kernel.
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) { for (auto func :
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!gpu::GPUDialect::isKernel(func)) if (!gpu::GPUDialect::isKernel(func))
continue; continue;

View File

@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder,
class ModuleTranslation : public LLVM::ModuleTranslation { class ModuleTranslation : public LLVM::ModuleTranslation {
public: public:
explicit ModuleTranslation(ModuleOp module) explicit ModuleTranslation(Operation *module)
: LLVM::ModuleTranslation(module) {} : LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {} ~ModuleTranslation() override {}
@ -84,7 +84,7 @@ protected:
}; };
} // namespace } // namespace
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) { std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
ModuleTranslation translation(m); ModuleTranslation translation(m);
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
// foreach GPU kernel // foreach GPU kernel
// 1. Insert AMDGPU_KERNEL calling convention. // 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute. // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) { for (auto func :
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName())) if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue; continue;

View File

@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
// Create named global variables that correspond to llvm.mlir.global // Create named global variables that correspond to llvm.mlir.global
// definitions. // definitions.
void ModuleTranslation::convertGlobals() { void ModuleTranslation::convertGlobals() {
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) { for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
llvm::Type *type = op.getType().getUnderlyingType(); llvm::Type *type = op.getType().getUnderlyingType();
llvm::Constant *cst = llvm::UndefValue::get(type); llvm::Constant *cst = llvm::UndefValue::get(type);
if (op.getValueOrNull()) { if (op.getValueOrNull()) {
@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
return success(); return success();
} }
LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) { LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
for (Operation &o : m.getBody()->getOperations()) for (Operation &o : getModuleBody(m).getOperations())
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) && if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
!isa<ModuleTerminatorOp>(&o)) !o.isKnownTerminator())
return o.emitOpError("unsupported module-level operation"); return o.emitOpError("unsupported module-level operation");
return success(); return success();
} }
@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
LogicalResult ModuleTranslation::convertFunctions() { LogicalResult ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a // Declare all functions first because there may be function calls that form a
// call graph with cycles. // call graph with cycles.
for (auto function : mlirModule.getOps<LLVMFuncOp>()) { for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(), function.getName(),
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType())); llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
} }
// Convert functions. // Convert functions.
for (auto function : mlirModule.getOps<LLVMFuncOp>()) { for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
// Ignore external functions. // Ignore external functions.
if (function.isExternal()) if (function.isExternal())
continue; continue;
@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() {
return success(); return success();
} }
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) { std::unique_ptr<llvm::Module>
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); ModuleTranslation::prepareLLVMModule(Operation *m) {
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered"); assert(dialect && "LLVM dialect must be registered");
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());