forked from OSchip/llvm-project
[mlir] Remove most uses of LLVMDialect::getModule
This prepares for the removal of llvm::Module and LLVMContext from the mlir::LLVMDialect. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D85371
This commit is contained in:
parent
d40c44e89e
commit
d3a9807674
|
@ -118,8 +118,7 @@ public:
|
||||||
unsigned getPointerBitwidth(unsigned addressSpace = 0);
|
unsigned getPointerBitwidth(unsigned addressSpace = 0);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// LLVM IR module used to parse/create types.
|
/// Pointer to the LLVM dialect.
|
||||||
llvm::Module *module;
|
|
||||||
LLVM::LLVMDialect *llvmDialect;
|
LLVM::LLVMDialect *llvmDialect;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -400,9 +399,6 @@ public:
|
||||||
/// Returns the LLVM IR context.
|
/// Returns the LLVM IR context.
|
||||||
llvm::LLVMContext &getContext() const;
|
llvm::LLVMContext &getContext() const;
|
||||||
|
|
||||||
/// Returns the LLVM IR module associated with the LLVM dialect.
|
|
||||||
llvm::Module &getModule() const;
|
|
||||||
|
|
||||||
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
|
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
|
||||||
/// defined by the used type converter.
|
/// defined by the used type converter.
|
||||||
LLVM::LLVMType getIndexType() const;
|
LLVM::LLVMType getIndexType() const;
|
||||||
|
@ -437,8 +433,8 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const;
|
ConversionPatternRewriter &rewriter) const;
|
||||||
|
|
||||||
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
|
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
|
||||||
ValueRange indices, ConversionPatternRewriter &rewriter,
|
ValueRange indices,
|
||||||
llvm::Module &module) const;
|
ConversionPatternRewriter &rewriter) const;
|
||||||
|
|
||||||
/// Returns the type of a pointer to an element of the memref.
|
/// Returns the type of a pointer to an element of the memref.
|
||||||
Type getElementPtrType(MemRefType type) const;
|
Type getElementPtrType(MemRefType type) const;
|
||||||
|
|
|
@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
|
||||||
llvm::LLVMContext &getLLVMContext();
|
llvm::LLVMContext &getLLVMContext();
|
||||||
llvm::Module &getLLVMModule();
|
llvm::Module &getLLVMModule();
|
||||||
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
|
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
|
||||||
|
const llvm::DataLayout &getDataLayout();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend LLVMType;
|
friend LLVMType;
|
||||||
|
|
|
@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
|
||||||
private:
|
private:
|
||||||
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
|
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
|
||||||
|
|
||||||
llvm::LLVMContext &getLLVMContext() {
|
|
||||||
return getLLVMDialect()->getLLVMContext();
|
|
||||||
}
|
|
||||||
|
|
||||||
void initializeCachedTypes() {
|
void initializeCachedTypes() {
|
||||||
const llvm::Module &module = llvmDialect->getLLVMModule();
|
|
||||||
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
|
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||||
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
llvmPointerPointerType = llvmPointerType.getPointerTo();
|
llvmPointerPointerType = llvmPointerType.getPointerTo();
|
||||||
|
@ -79,7 +74,7 @@ private:
|
||||||
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||||
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
|
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
|
||||||
llvmDialect, module.getDataLayout().getPointerSizeInBits());
|
llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::LLVMType getVoidType() { return llvmVoidType; }
|
LLVM::LLVMType getVoidType() { return llvmVoidType; }
|
||||||
|
@ -95,9 +90,9 @@ private:
|
||||||
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
|
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
|
||||||
|
|
||||||
LLVM::LLVMType getIntPtrType() {
|
LLVM::LLVMType getIntPtrType() {
|
||||||
const llvm::Module &module = getLLVMDialect()->getLLVMModule();
|
|
||||||
return LLVM::LLVMType::getIntNTy(
|
return LLVM::LLVMType::getIntNTy(
|
||||||
getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
|
getLLVMDialect(),
|
||||||
|
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate a void pointer on the stack.
|
// Allocate a void pointer on the stack.
|
||||||
|
|
|
@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
|
||||||
private:
|
private:
|
||||||
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
|
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
|
||||||
|
|
||||||
llvm::LLVMContext &getLLVMContext() {
|
|
||||||
return getLLVMDialect()->getLLVMContext();
|
|
||||||
}
|
|
||||||
|
|
||||||
void initializeCachedTypes() {
|
void initializeCachedTypes() {
|
||||||
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
|
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
|
||||||
|
|
|
@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
|
||||||
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
|
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
|
||||||
options(options) {
|
options(options) {
|
||||||
assert(llvmDialect && "LLVM IR dialect is not registered");
|
assert(llvmDialect && "LLVM IR dialect is not registered");
|
||||||
module = &llvmDialect->getLLVMModule();
|
|
||||||
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
|
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
|
||||||
this->options.indexBitwidth =
|
this->options.indexBitwidth =
|
||||||
module->getDataLayout().getPointerSizeInBits();
|
llvmDialect->getDataLayout().getPointerSizeInBits();
|
||||||
|
|
||||||
// Register conversions for the standard types.
|
// Register conversions for the standard types.
|
||||||
addConversion([&](ComplexType type) { return convertComplexType(type); });
|
addConversion([&](ComplexType type) { return convertComplexType(type); });
|
||||||
|
@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
|
||||||
|
|
||||||
/// Get the LLVM context.
|
/// Get the LLVM context.
|
||||||
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
|
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
|
||||||
return module->getContext();
|
return llvmDialect->getLLVMContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
|
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
|
||||||
|
@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
|
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
|
||||||
return module->getDataLayout().getPointerSizeInBits(addressSpace);
|
return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type LLVMTypeConverter::convertIndexType(IndexType type) {
|
Type LLVMTypeConverter::convertIndexType(IndexType type) {
|
||||||
|
@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
|
||||||
return typeConverter.getLLVMContext();
|
return typeConverter.getLLVMContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Module &ConvertToLLVMPattern::getModule() const {
|
|
||||||
return getDialect().getLLVMModule();
|
|
||||||
}
|
|
||||||
|
|
||||||
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
|
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
|
||||||
return typeConverter.getIndexType();
|
return typeConverter.getIndexType();
|
||||||
}
|
}
|
||||||
|
@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
|
||||||
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
|
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
|
Value ConvertToLLVMPattern::getDataPtr(
|
||||||
Value memRefDesc, ValueRange indices,
|
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter) const {
|
||||||
llvm::Module &module) const {
|
|
||||||
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
|
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
|
||||||
int64_t offset;
|
int64_t offset;
|
||||||
SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
|
@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
||||||
auto type = loadOp.getMemRefType();
|
auto type = loadOp.getMemRefType();
|
||||||
|
|
||||||
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
||||||
transformed.indices(), rewriter, getModule());
|
transformed.indices(), rewriter);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
||||||
StoreOp::Adaptor transformed(operands);
|
StoreOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
||||||
transformed.indices(), rewriter, getModule());
|
transformed.indices(), rewriter);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
|
||||||
dataPtr);
|
dataPtr);
|
||||||
return success();
|
return success();
|
||||||
|
@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
||||||
auto type = prefetchOp.getMemRefType();
|
auto type = prefetchOp.getMemRefType();
|
||||||
|
|
||||||
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
||||||
transformed.indices(), rewriter, getModule());
|
transformed.indices(), rewriter);
|
||||||
|
|
||||||
// Replace with llvm.prefetch.
|
// Replace with llvm.prefetch.
|
||||||
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
|
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
|
||||||
|
@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
||||||
auto resultType = adaptor.value().getType();
|
auto resultType = adaptor.value().getType();
|
||||||
auto memRefType = atomicOp.getMemRefType();
|
auto memRefType = atomicOp.getMemRefType();
|
||||||
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
|
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
|
||||||
adaptor.indices(), rewriter, getModule());
|
adaptor.indices(), rewriter);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
||||||
op, resultType, *maybeKind, dataPtr, adaptor.value(),
|
op, resultType, *maybeKind, dataPtr, adaptor.value(),
|
||||||
LLVM::AtomicOrdering::acq_rel);
|
LLVM::AtomicOrdering::acq_rel);
|
||||||
|
@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
|
||||||
rewriter.setInsertionPointToEnd(initBlock);
|
rewriter.setInsertionPointToEnd(initBlock);
|
||||||
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
|
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
|
||||||
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
||||||
adaptor.indices(), rewriter, getModule());
|
adaptor.indices(), rewriter);
|
||||||
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
|
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
|
||||||
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
|
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
|
||||||
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
||||||
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
|
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
|
||||||
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
|
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
|
||||||
dialect->getLLVMModule().getDataLayout());
|
dialect->getDataLayout());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1152,7 +1152,7 @@ public:
|
||||||
// address space 0.
|
// address space 0.
|
||||||
// TODO: support alignment when possible.
|
// TODO: support alignment when possible.
|
||||||
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
||||||
adaptor.indices(), rewriter, getModule());
|
adaptor.indices(), rewriter);
|
||||||
auto vecTy =
|
auto vecTy =
|
||||||
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
||||||
Value vectorDataPtr;
|
Value vectorDataPtr;
|
||||||
|
|
|
@ -103,7 +103,7 @@ public:
|
||||||
// indices, so no need to calculat offset size in bytes again in
|
// indices, so no need to calculat offset size in bytes again in
|
||||||
// the MUBUF instruction.
|
// the MUBUF instruction.
|
||||||
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
||||||
adaptor.indices(), rewriter, getModule());
|
adaptor.indices(), rewriter);
|
||||||
|
|
||||||
// 1. Create and fill a <4 x i32> dwordConfig with:
|
// 1. Create and fill a <4 x i32> dwordConfig with:
|
||||||
// 1st two elements holding the address of dataPtr.
|
// 1st two elements holding the address of dataPtr.
|
||||||
|
|
|
@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
|
||||||
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
|
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
|
||||||
return impl->mutex;
|
return impl->mutex;
|
||||||
}
|
}
|
||||||
|
const llvm::DataLayout &LLVMDialect::getDataLayout() {
|
||||||
|
return impl->module.getDataLayout();
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse a type registered to this dialect.
|
/// Parse a type registered to this dialect.
|
||||||
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
|
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
|
Loading…
Reference in New Issue