[mlir] Remove most uses of LLVMDialect::getModule

This prepares for the removal of llvm::Module and LLVMContext from the
mlir::LLVMDialect.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85371
This commit is contained in:
Alex Zinenko 2020-08-06 00:52:10 +02:00
parent d40c44e89e
commit d3a9807674
8 changed files with 24 additions and 39 deletions

View File

@ -118,8 +118,7 @@ public:
unsigned getPointerBitwidth(unsigned addressSpace = 0);
protected:
/// LLVM IR module used to parse/create types.
llvm::Module *module;
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
private:
@ -400,9 +399,6 @@ public:
/// Returns the LLVM IR context.
llvm::LLVMContext &getContext() const;
/// Returns the LLVM IR module associated with the LLVM dialect.
llvm::Module &getModule() const;
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
LLVM::LLVMType getIndexType() const;
@ -437,8 +433,8 @@ public:
ConversionPatternRewriter &rewriter) const;
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
ValueRange indices, ConversionPatternRewriter &rewriter,
llvm::Module &module) const;
ValueRange indices,
ConversionPatternRewriter &rewriter) const;
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const;

View File

@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule();
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
const llvm::DataLayout &getDataLayout();
private:
friend LLVMType;

View File

@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
llvm::LLVMContext &getLLVMContext() {
return getLLVMDialect()->getLLVMContext();
}
void initializeCachedTypes() {
const llvm::Module &module = llvmDialect->getLLVMModule();
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmPointerPointerType = llvmPointerType.getPointerTo();
@ -79,7 +74,7 @@ private:
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
llvmDialect, module.getDataLayout().getPointerSizeInBits());
llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
@ -95,9 +90,9 @@ private:
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
LLVM::LLVMType getIntPtrType() {
const llvm::Module &module = getLLVMDialect()->getLLVMModule();
return LLVM::LLVMType::getIntNTy(
getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
getLLVMDialect(),
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
// Allocate a void pointer on the stack.

View File

@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
llvm::LLVMContext &getLLVMContext() {
return getLLVMDialect()->getLLVMContext();
}
void initializeCachedTypes() {
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);

View File

@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
this->options.indexBitwidth =
module->getDataLayout().getPointerSizeInBits();
llvmDialect->getDataLayout().getPointerSizeInBits();
// Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); });
@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
/// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
return module->getContext();
return llvmDialect->getLLVMContext();
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
return module->getDataLayout().getPointerSizeInBits(addressSpace);
return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
return typeConverter.getLLVMContext();
}
llvm::Module &ConvertToLLVMPattern::getModule() const {
return getDialect().getLLVMModule();
}
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return typeConverter.getIndexType();
}
@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
Value ConvertToLLVMPattern::getDataPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
int64_t offset;
SmallVector<int64_t, 4> strides;
@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
auto type = loadOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return success();
}
@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
StoreOp::Adaptor transformed(operands);
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return success();
@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
auto type = prefetchOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
transformed.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

View File

@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
dialect->getLLVMModule().getDataLayout());
dialect->getDataLayout());
return success();
}
@ -1152,7 +1152,7 @@ public:
// address space 0.
// TODO: support alignment when possible.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;

View File

@ -103,7 +103,7 @@ public:
// indices, so no need to calculat offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.

View File

@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
return impl->mutex;
}
const llvm::DataLayout &LLVMDialect::getDataLayout() {
return impl->module.getDataLayout();
}
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser) const {