forked from OSchip/llvm-project
Remove LLVM dependency on mlir::Module and instead check Traits.
PiperOrigin-RevId: 285724678
This commit is contained in:
parent
97af932272
commit
44fc7d72b3
|
@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
|
|||
StringRef value, LLVM::Linkage linkage,
|
||||
LLVM::LLVMDialect *llvmDialect);
|
||||
|
||||
/// LLVM requires some operations to be inside of a Module operation. This
|
||||
/// function confirms that the Operation has the desired properties.
|
||||
bool satisfiesLLVMModule(Operation *op);
|
||||
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
@ -50,7 +51,9 @@ class LLVMFuncOp;
|
|||
class ModuleTranslation {
|
||||
public:
|
||||
template <typename T = ModuleTranslation>
|
||||
static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
|
||||
static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
|
||||
if (!satisfiesLLVMModule(m))
|
||||
return nullptr;
|
||||
if (failed(checkSupportedModuleOps(m)))
|
||||
return nullptr;
|
||||
auto llvmModule = prepareLLVMModule(m);
|
||||
|
@ -66,23 +69,30 @@ public:
|
|||
return std::move(translator.llvmModule);
|
||||
}
|
||||
|
||||
/// A helper method to get the single Block in an operation honoring LLVM's
|
||||
/// module requirements.
|
||||
static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
|
||||
|
||||
protected:
|
||||
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
|
||||
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
|
||||
// LLVMContext, the LLVM IR module will be created in that context.
|
||||
explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
|
||||
explicit ModuleTranslation(Operation *module) : mlirModule(module) {
|
||||
assert(satisfiesLLVMModule(mlirModule) &&
|
||||
"mlirModule should honor LLVM's module semantics.");
|
||||
}
|
||||
virtual ~ModuleTranslation() {}
|
||||
|
||||
virtual LogicalResult convertOperation(Operation &op,
|
||||
llvm::IRBuilder<> &builder);
|
||||
static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
|
||||
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
|
||||
|
||||
template <typename Range>
|
||||
SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
|
||||
|
||||
private:
|
||||
/// Check whether the module contains only supported ops directly in its body.
|
||||
static LogicalResult checkSupportedModuleOps(ModuleOp m);
|
||||
static LogicalResult checkSupportedModuleOps(Operation *m);
|
||||
|
||||
LogicalResult convertFunctions();
|
||||
void convertGlobals();
|
||||
|
@ -94,7 +104,7 @@ private:
|
|||
Location loc);
|
||||
|
||||
// Original and translated module.
|
||||
ModuleOp mlirModule;
|
||||
Operation *mlirModule;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
|
||||
// Mappings between llvm.mlir.global definitions and corresponding globals.
|
||||
|
|
|
@ -30,14 +30,14 @@ class Module;
|
|||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
class Operation;
|
||||
|
||||
/// Convert the given MLIR module into NVVM IR. This conversion requires the
|
||||
/// registration of the LLVM IR dialect and will extract the LLVM context
|
||||
/// from the registered LLVM IR dialect. In case of error, report it
|
||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
||||
/// Convert the given LLVM-module-like operation into NVVM IR. This conversion
|
||||
/// requires the registration of the LLVM IR dialect and will extract the LLVM
|
||||
/// context from the registered LLVM IR dialect. In case of error, report it to
|
||||
/// the error handler registered with the MLIR context, if any (obtained from
|
||||
/// the MLIR module), and return `nullptr`.
|
||||
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
|
||||
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -31,14 +31,14 @@ class Module;
|
|||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
class Operation;
|
||||
|
||||
/// Convert the given MLIR module into ROCDL IR. This conversion requires the
|
||||
/// registration of the LLVM IR dialect and will extract the LLVM context
|
||||
/// from the registered LLVM IR dialect. In case of error, report it
|
||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
||||
/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
|
||||
/// requires the registration of the LLVM IR dialect and will extract the LLVM
|
||||
/// context from the registered LLVM IR dialect. In case of error, report it to
|
||||
/// the error handler registered with the MLIR context, if any (obtained from
|
||||
/// the MLIR module), and return `nullptr`.
|
||||
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m);
|
||||
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp AddressOfOp::getGlobal() {
|
||||
auto module = getParentOfType<ModuleOp>();
|
||||
Operation *module = getParentOp();
|
||||
while (module && !satisfiesLLVMModule(module))
|
||||
module = module->getParentOp();
|
||||
assert(module && "unexpected operation outside of a module");
|
||||
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
|
||||
return dyn_cast_or_null<LLVM::GlobalOp>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
|
||||
}
|
||||
|
||||
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
||||
|
@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) {
|
|||
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
|
||||
return op.emitOpError(
|
||||
"expects type to be a valid element type for an LLVM pointer");
|
||||
if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
|
||||
if (op.getParentOp() &&
|
||||
!(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
|
||||
op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
|
||||
return op.emitOpError("must appear at the module level");
|
||||
|
||||
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
|
||||
|
@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
|
|||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
|
||||
ArrayRef<Value *>({cst0, cst0}));
|
||||
}
|
||||
|
||||
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
|
||||
return op->hasTrait<OpTrait::SymbolTable>() &&
|
||||
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
|
|||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||
|
||||
public:
|
||||
explicit ModuleTranslation(ModuleOp module)
|
||||
explicit ModuleTranslation(Operation *module)
|
||||
: LLVM::ModuleTranslation(module) {}
|
||||
~ModuleTranslation() override {}
|
||||
|
||||
|
@ -73,7 +73,7 @@ protected:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
|
||||
ModuleTranslation translation(m);
|
||||
auto llvmModule =
|
||||
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
||||
|
@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
|||
|
||||
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
||||
// function as a kernel.
|
||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||
for (auto func :
|
||||
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
|
||||
if (!gpu::GPUDialect::isKernel(func))
|
||||
continue;
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder,
|
|||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||
|
||||
public:
|
||||
explicit ModuleTranslation(ModuleOp module)
|
||||
explicit ModuleTranslation(Operation *module)
|
||||
: LLVM::ModuleTranslation(module) {}
|
||||
~ModuleTranslation() override {}
|
||||
|
||||
|
@ -84,7 +84,7 @@ protected:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
|
||||
ModuleTranslation translation(m);
|
||||
|
||||
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
|
||||
|
@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
|||
// foreach GPU kernel
|
||||
// 1. Insert AMDGPU_KERNEL calling convention.
|
||||
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||
for (auto func :
|
||||
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
|
||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||
continue;
|
||||
|
||||
|
|
|
@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
|
|||
// Create named global variables that correspond to llvm.mlir.global
|
||||
// definitions.
|
||||
void ModuleTranslation::convertGlobals() {
|
||||
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
|
||||
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
|
||||
llvm::Type *type = op.getType().getUnderlyingType();
|
||||
llvm::Constant *cst = llvm::UndefValue::get(type);
|
||||
if (op.getValueOrNull()) {
|
||||
|
@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
|
||||
for (Operation &o : m.getBody()->getOperations())
|
||||
LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
|
||||
for (Operation &o : getModuleBody(m).getOperations())
|
||||
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
|
||||
!isa<ModuleTerminatorOp>(&o))
|
||||
!o.isKnownTerminator())
|
||||
return o.emitOpError("unsupported module-level operation");
|
||||
return success();
|
||||
}
|
||||
|
@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
|
|||
LogicalResult ModuleTranslation::convertFunctions() {
|
||||
// Declare all functions first because there may be function calls that form a
|
||||
// call graph with cycles.
|
||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
|
||||
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
|
||||
function.getName(),
|
||||
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
|
||||
|
@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
|
|||
}
|
||||
|
||||
// Convert functions.
|
||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
|
||||
// Ignore external functions.
|
||||
if (function.isExternal())
|
||||
continue;
|
||||
|
@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() {
|
|||
return success();
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
|
||||
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
std::unique_ptr<llvm::Module>
|
||||
ModuleTranslation::prepareLLVMModule(Operation *m) {
|
||||
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(dialect && "LLVM dialect must be registered");
|
||||
|
||||
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());
|
||||
|
|
Loading…
Reference in New Issue