forked from OSchip/llvm-project
Remove LLVM dependency on mlir::Module and instead check Traits.
PiperOrigin-RevId: 285724678
This commit is contained in:
parent
97af932272
commit
44fc7d72b3
|
@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
|
||||||
StringRef value, LLVM::Linkage linkage,
|
StringRef value, LLVM::Linkage linkage,
|
||||||
LLVM::LLVMDialect *llvmDialect);
|
LLVM::LLVMDialect *llvmDialect);
|
||||||
|
|
||||||
|
/// LLVM requires some operations to be inside of a Module operation. This
|
||||||
|
/// function confirms that the Operation has the desired properties.
|
||||||
|
bool satisfiesLLVMModule(Operation *op);
|
||||||
|
|
||||||
} // end namespace LLVM
|
} // end namespace LLVM
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||||
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
@ -50,7 +51,9 @@ class LLVMFuncOp;
|
||||||
class ModuleTranslation {
|
class ModuleTranslation {
|
||||||
public:
|
public:
|
||||||
template <typename T = ModuleTranslation>
|
template <typename T = ModuleTranslation>
|
||||||
static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
|
static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
|
||||||
|
if (!satisfiesLLVMModule(m))
|
||||||
|
return nullptr;
|
||||||
if (failed(checkSupportedModuleOps(m)))
|
if (failed(checkSupportedModuleOps(m)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto llvmModule = prepareLLVMModule(m);
|
auto llvmModule = prepareLLVMModule(m);
|
||||||
|
@ -66,23 +69,30 @@ public:
|
||||||
return std::move(translator.llvmModule);
|
return std::move(translator.llvmModule);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A helper method to get the single Block in an operation honoring LLVM's
|
||||||
|
/// module requirements.
|
||||||
|
static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
|
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
|
||||||
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
|
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
|
||||||
// LLVMContext, the LLVM IR module will be created in that context.
|
// LLVMContext, the LLVM IR module will be created in that context.
|
||||||
explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
|
explicit ModuleTranslation(Operation *module) : mlirModule(module) {
|
||||||
|
assert(satisfiesLLVMModule(mlirModule) &&
|
||||||
|
"mlirModule should honor LLVM's module semantics.");
|
||||||
|
}
|
||||||
virtual ~ModuleTranslation() {}
|
virtual ~ModuleTranslation() {}
|
||||||
|
|
||||||
virtual LogicalResult convertOperation(Operation &op,
|
virtual LogicalResult convertOperation(Operation &op,
|
||||||
llvm::IRBuilder<> &builder);
|
llvm::IRBuilder<> &builder);
|
||||||
static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
|
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
|
||||||
|
|
||||||
template <typename Range>
|
template <typename Range>
|
||||||
SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
|
SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Check whether the module contains only supported ops directly in its body.
|
/// Check whether the module contains only supported ops directly in its body.
|
||||||
static LogicalResult checkSupportedModuleOps(ModuleOp m);
|
static LogicalResult checkSupportedModuleOps(Operation *m);
|
||||||
|
|
||||||
LogicalResult convertFunctions();
|
LogicalResult convertFunctions();
|
||||||
void convertGlobals();
|
void convertGlobals();
|
||||||
|
@ -94,7 +104,7 @@ private:
|
||||||
Location loc);
|
Location loc);
|
||||||
|
|
||||||
// Original and translated module.
|
// Original and translated module.
|
||||||
ModuleOp mlirModule;
|
Operation *mlirModule;
|
||||||
std::unique_ptr<llvm::Module> llvmModule;
|
std::unique_ptr<llvm::Module> llvmModule;
|
||||||
|
|
||||||
// Mappings between llvm.mlir.global definitions and corresponding globals.
|
// Mappings between llvm.mlir.global definitions and corresponding globals.
|
||||||
|
|
|
@ -30,14 +30,14 @@ class Module;
|
||||||
} // namespace llvm
|
} // namespace llvm
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class ModuleOp;
|
class Operation;
|
||||||
|
|
||||||
/// Convert the given MLIR module into NVVM IR. This conversion requires the
|
/// Convert the given LLVM-module-like operation into NVVM IR. This conversion
|
||||||
/// registration of the LLVM IR dialect and will extract the LLVM context
|
/// requires the registration of the LLVM IR dialect and will extract the LLVM
|
||||||
/// from the registered LLVM IR dialect. In case of error, report it
|
/// context from the registered LLVM IR dialect. In case of error, report it to
|
||||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
/// the error handler registered with the MLIR context, if any (obtained from
|
||||||
/// the MLIR module), and return `nullptr`.
|
/// the MLIR module), and return `nullptr`.
|
||||||
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
|
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -31,14 +31,14 @@ class Module;
|
||||||
} // namespace llvm
|
} // namespace llvm
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class ModuleOp;
|
class Operation;
|
||||||
|
|
||||||
/// Convert the given MLIR module into ROCDL IR. This conversion requires the
|
/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
|
||||||
/// registration of the LLVM IR dialect and will extract the LLVM context
|
/// requires the registration of the LLVM IR dialect and will extract the LLVM
|
||||||
/// from the registered LLVM IR dialect. In case of error, report it
|
/// context from the registered LLVM IR dialect. In case of error, report it to
|
||||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
/// the error handler registered with the MLIR context, if any (obtained from
|
||||||
/// the MLIR module), and return `nullptr`.
|
/// the MLIR module), and return `nullptr`.
|
||||||
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m);
|
std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
GlobalOp AddressOfOp::getGlobal() {
|
GlobalOp AddressOfOp::getGlobal() {
|
||||||
auto module = getParentOfType<ModuleOp>();
|
Operation *module = getParentOp();
|
||||||
|
while (module && !satisfiesLLVMModule(module))
|
||||||
|
module = module->getParentOp();
|
||||||
assert(module && "unexpected operation outside of a module");
|
assert(module && "unexpected operation outside of a module");
|
||||||
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
|
return dyn_cast_or_null<LLVM::GlobalOp>(
|
||||||
|
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
||||||
|
@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) {
|
||||||
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
|
if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
"expects type to be a valid element type for an LLVM pointer");
|
"expects type to be a valid element type for an LLVM pointer");
|
||||||
if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
|
if (op.getParentOp() &&
|
||||||
|
!(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
|
||||||
|
op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
|
||||||
return op.emitOpError("must appear at the module level");
|
return op.emitOpError("must appear at the module level");
|
||||||
|
|
||||||
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
|
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
|
||||||
|
@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
|
||||||
ArrayRef<Value *>({cst0, cst0}));
|
ArrayRef<Value *>({cst0, cst0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
|
||||||
|
return op->hasTrait<OpTrait::SymbolTable>() &&
|
||||||
|
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
|
||||||
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
|
||||||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ModuleTranslation(ModuleOp module)
|
explicit ModuleTranslation(Operation *module)
|
||||||
: LLVM::ModuleTranslation(module) {}
|
: LLVM::ModuleTranslation(module) {}
|
||||||
~ModuleTranslation() override {}
|
~ModuleTranslation() override {}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ protected:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
|
||||||
ModuleTranslation translation(m);
|
ModuleTranslation translation(m);
|
||||||
auto llvmModule =
|
auto llvmModule =
|
||||||
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
||||||
|
@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
||||||
|
|
||||||
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
||||||
// function as a kernel.
|
// function as a kernel.
|
||||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
for (auto func :
|
||||||
|
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
|
||||||
if (!gpu::GPUDialect::isKernel(func))
|
if (!gpu::GPUDialect::isKernel(func))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder,
|
||||||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ModuleTranslation(ModuleOp module)
|
explicit ModuleTranslation(Operation *module)
|
||||||
: LLVM::ModuleTranslation(module) {}
|
: LLVM::ModuleTranslation(module) {}
|
||||||
~ModuleTranslation() override {}
|
~ModuleTranslation() override {}
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ protected:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
|
||||||
ModuleTranslation translation(m);
|
ModuleTranslation translation(m);
|
||||||
|
|
||||||
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
|
// lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
|
||||||
|
@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
||||||
// foreach GPU kernel
|
// foreach GPU kernel
|
||||||
// 1. Insert AMDGPU_KERNEL calling convention.
|
// 1. Insert AMDGPU_KERNEL calling convention.
|
||||||
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
||||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
for (auto func :
|
||||||
|
ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
|
||||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
|
@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
|
||||||
// Create named global variables that correspond to llvm.mlir.global
|
// Create named global variables that correspond to llvm.mlir.global
|
||||||
// definitions.
|
// definitions.
|
||||||
void ModuleTranslation::convertGlobals() {
|
void ModuleTranslation::convertGlobals() {
|
||||||
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
|
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
|
||||||
llvm::Type *type = op.getType().getUnderlyingType();
|
llvm::Type *type = op.getType().getUnderlyingType();
|
||||||
llvm::Constant *cst = llvm::UndefValue::get(type);
|
llvm::Constant *cst = llvm::UndefValue::get(type);
|
||||||
if (op.getValueOrNull()) {
|
if (op.getValueOrNull()) {
|
||||||
|
@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
|
LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
|
||||||
for (Operation &o : m.getBody()->getOperations())
|
for (Operation &o : getModuleBody(m).getOperations())
|
||||||
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
|
if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
|
||||||
!isa<ModuleTerminatorOp>(&o))
|
!o.isKnownTerminator())
|
||||||
return o.emitOpError("unsupported module-level operation");
|
return o.emitOpError("unsupported module-level operation");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
|
||||||
LogicalResult ModuleTranslation::convertFunctions() {
|
LogicalResult ModuleTranslation::convertFunctions() {
|
||||||
// Declare all functions first because there may be function calls that form a
|
// Declare all functions first because there may be function calls that form a
|
||||||
// call graph with cycles.
|
// call graph with cycles.
|
||||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
|
||||||
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
|
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
|
||||||
function.getName(),
|
function.getName(),
|
||||||
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
|
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
|
||||||
|
@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert functions.
|
// Convert functions.
|
||||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
|
||||||
// Ignore external functions.
|
// Ignore external functions.
|
||||||
if (function.isExternal())
|
if (function.isExternal())
|
||||||
continue;
|
continue;
|
||||||
|
@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
|
std::unique_ptr<llvm::Module>
|
||||||
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
ModuleTranslation::prepareLLVMModule(Operation *m) {
|
||||||
|
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
assert(dialect && "LLVM dialect must be registered");
|
assert(dialect && "LLVM dialect must be registered");
|
||||||
|
|
||||||
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());
|
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());
|
||||||
|
|
Loading…
Reference in New Issue