[mlir] Remove most uses of LLVMDialect::getModule

This prepares for the removal of llvm::Module and LLVMContext from the
mlir::LLVMDialect.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85371
This commit is contained in:
Alex Zinenko 2020-08-06 00:52:10 +02:00
parent d40c44e89e
commit d3a9807674
8 changed files with 24 additions and 39 deletions

View File

@ -118,8 +118,7 @@ public:
unsigned getPointerBitwidth(unsigned addressSpace = 0); unsigned getPointerBitwidth(unsigned addressSpace = 0);
protected: protected:
/// LLVM IR module used to parse/create types. /// Pointer to the LLVM dialect.
llvm::Module *module;
LLVM::LLVMDialect *llvmDialect; LLVM::LLVMDialect *llvmDialect;
private: private:
@ -400,9 +399,6 @@ public:
/// Returns the LLVM IR context. /// Returns the LLVM IR context.
llvm::LLVMContext &getContext() const; llvm::LLVMContext &getContext() const;
/// Returns the LLVM IR module associated with the LLVM dialect.
llvm::Module &getModule() const;
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter. /// defined by the used type converter.
LLVM::LLVMType getIndexType() const; LLVM::LLVMType getIndexType() const;
@ -437,8 +433,8 @@ public:
ConversionPatternRewriter &rewriter) const; ConversionPatternRewriter &rewriter) const;
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
ValueRange indices, ConversionPatternRewriter &rewriter, ValueRange indices,
llvm::Module &module) const; ConversionPatternRewriter &rewriter) const;
/// Returns the type of a pointer to an element of the memref. /// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const; Type getElementPtrType(MemRefType type) const;

View File

@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
llvm::LLVMContext &getLLVMContext(); llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule(); llvm::Module &getLLVMModule();
llvm::sys::SmartMutex<true> &getLLVMContextMutex(); llvm::sys::SmartMutex<true> &getLLVMContextMutex();
const llvm::DataLayout &getDataLayout();
private: private:
friend LLVMType; friend LLVMType;

View File

@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
private: private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
llvm::LLVMContext &getLLVMContext() {
return getLLVMDialect()->getLLVMContext();
}
void initializeCachedTypes() { void initializeCachedTypes() {
const llvm::Module &module = llvmDialect->getLLVMModule();
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmPointerPointerType = llvmPointerType.getPointerTo(); llvmPointerPointerType = llvmPointerType.getPointerTo();
@ -79,7 +74,7 @@ private:
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmIntPtrType = LLVM::LLVMType::getIntNTy( llvmIntPtrType = LLVM::LLVMType::getIntNTy(
llvmDialect, module.getDataLayout().getPointerSizeInBits()); llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
} }
LLVM::LLVMType getVoidType() { return llvmVoidType; } LLVM::LLVMType getVoidType() { return llvmVoidType; }
@ -95,9 +90,9 @@ private:
LLVM::LLVMType getInt64Type() { return llvmInt64Type; } LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
LLVM::LLVMType getIntPtrType() { LLVM::LLVMType getIntPtrType() {
const llvm::Module &module = getLLVMDialect()->getLLVMModule();
return LLVM::LLVMType::getIntNTy( return LLVM::LLVMType::getIntNTy(
getLLVMDialect(), module.getDataLayout().getPointerSizeInBits()); getLLVMDialect(),
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
} }
// Allocate a void pointer on the stack. // Allocate a void pointer on the stack.

View File

@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
private: private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
llvm::LLVMContext &getLLVMContext() {
return getLLVMDialect()->getLLVMContext();
}
void initializeCachedTypes() { void initializeCachedTypes() {
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);

View File

@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()), : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) { options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered"); assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
this->options.indexBitwidth = this->options.indexBitwidth =
module->getDataLayout().getPointerSizeInBits(); llvmDialect->getDataLayout().getPointerSizeInBits();
// Register conversions for the standard types. // Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); }); addConversion([&](ComplexType type) { return convertComplexType(type); });
@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
/// Get the LLVM context. /// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() { llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
return module->getContext(); return llvmDialect->getLLVMContext();
} }
LLVM::LLVMType LLVMTypeConverter::getIndexType() { LLVM::LLVMType LLVMTypeConverter::getIndexType() {
@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
} }
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
return module->getDataLayout().getPointerSizeInBits(addressSpace); return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
} }
Type LLVMTypeConverter::convertIndexType(IndexType type) { Type LLVMTypeConverter::convertIndexType(IndexType type) {
@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
return typeConverter.getLLVMContext(); return typeConverter.getLLVMContext();
} }
llvm::Module &ConvertToLLVMPattern::getModule() const {
return getDialect().getLLVMModule();
}
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return typeConverter.getIndexType(); return typeConverter.getIndexType();
} }
@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue); return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
} }
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type, Value ConvertToLLVMPattern::getDataPtr(
Value memRefDesc, ValueRange indices, Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter) const {
llvm::Module &module) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
int64_t offset; int64_t offset;
SmallVector<int64_t, 4> strides; SmallVector<int64_t, 4> strides;
@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
auto type = loadOp.getMemRefType(); auto type = loadOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule()); transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr); rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return success(); return success();
} }
@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
StoreOp::Adaptor transformed(operands); StoreOp::Adaptor transformed(operands);
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule()); transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(), rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr); dataPtr);
return success(); return success();
@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
auto type = prefetchOp.getMemRefType(); auto type = prefetchOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule()); transformed.indices(), rewriter);
// Replace with llvm.prefetch. // Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto resultType = adaptor.value().getType(); auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType(); auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule()); adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(), op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel); LLVM::AtomicOrdering::acq_rel);
@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
rewriter.setInsertionPointToEnd(initBlock); rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>(); auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule()); adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr); Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock); rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

View File

@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
LLVM::LLVMDialect *dialect = typeConverter.getDialect(); LLVM::LLVMDialect *dialect = typeConverter.getDialect();
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext()) align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(), .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
dialect->getLLVMModule().getDataLayout()); dialect->getDataLayout());
return success(); return success();
} }
@ -1152,7 +1152,7 @@ public:
// address space 0. // address space 0.
// TODO: support alignment when possible. // TODO: support alignment when possible.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule()); adaptor.indices(), rewriter);
auto vecTy = auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr; Value vectorDataPtr;

View File

@ -103,7 +103,7 @@ public:
// indices, so no need to calculat offset size in bytes again in // indices, so no need to calculat offset size in bytes again in
// the MUBUF instruction. // the MUBUF instruction.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule()); adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with: // 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr. // 1st two elements holding the address of dataPtr.

View File

@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() { llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
return impl->mutex; return impl->mutex;
} }
const llvm::DataLayout &LLVMDialect::getDataLayout() {
return impl->module.getDataLayout();
}
/// Parse a type registered to this dialect. /// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser) const { Type LLVMDialect::parseType(DialectAsmParser &parser) const {