[mlir] take MLIRContext instead of LLVMDialect in getters of LLVMType's

Historical modeling of the LLVM dialect types had been wrapping LLVM IR types
and therefore needed access to the instance of LLVMContext stored in the
LLVMDialect. The new modeling does not rely on that and only needs the
MLIRContext that is used for uniquing, similarly to other MLIR types. Change
LLVMType::get<Kind>Ty functions to take `MLIRContext *` instead of
`LLVMDialect *` as first argument. This brings the code base closer to
completely removing the dependence on LLVMContext from the LLVMDialect,
together with additional support for thread-safety of its use.

Depends On D85371

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85372
This commit is contained in:
Alex Zinenko 2020-08-06 00:52:20 +02:00
parent d3a9807674
commit 5446ec8507
19 changed files with 185 additions and 234 deletions

View File

@ -56,19 +56,15 @@ public:
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
llvmDialect);
loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
@ -108,16 +104,15 @@ private:
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
@ -132,15 +127,14 @@ private:
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@ -149,10 +143,10 @@ private:
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};

View File

@ -56,19 +56,15 @@ public:
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
llvmDialect);
loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
@ -108,16 +104,15 @@ private:
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
@ -132,15 +127,14 @@ private:
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@ -149,10 +143,10 @@ private:
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};

View File

@ -59,8 +59,7 @@ struct LLVMDialectImpl;
/// global and use it to compute the address of the first character in the
/// string (operations inserted at the builder insertion point).
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
StringRef value, LLVM::Linkage linkage,
LLVM::LLVMDialect *llvmDialect);
StringRef value, LLVM::Linkage linkage);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.

View File

@ -58,8 +58,7 @@ class LLVMI<int width>
"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
"LLVM dialect " # width # "-bit integer">,
BuildableType<
"::mlir::LLVM::LLVMType::getIntNTy("
"$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
"::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
# width # ")">;
def LLVMI1 : LLVMI<1>;

View File

@ -151,8 +151,7 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ICmpPredicate predicate, Value lhs, "
"Value rhs", [{
LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
build(b, result, LLVMType::getInt1Ty(dialect),
build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
@ -198,8 +197,7 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, FCmpPredicate predicate, Value lhs, "
"Value rhs", [{
LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
build(b, result, LLVMType::getInt1Ty(dialect),
build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];

View File

@ -152,32 +152,32 @@ public:
bool isStructTy();
/// Utilities used to generate floating point types.
static LLVMType getDoubleTy(LLVMDialect *dialect);
static LLVMType getFloatTy(LLVMDialect *dialect);
static LLVMType getBFloatTy(LLVMDialect *dialect);
static LLVMType getHalfTy(LLVMDialect *dialect);
static LLVMType getFP128Ty(LLVMDialect *dialect);
static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
static LLVMType getDoubleTy(MLIRContext *context);
static LLVMType getFloatTy(MLIRContext *context);
static LLVMType getBFloatTy(MLIRContext *context);
static LLVMType getHalfTy(MLIRContext *context);
static LLVMType getFP128Ty(MLIRContext *context);
static LLVMType getX86_FP80Ty(MLIRContext *context);
/// Utilities used to generate integer types.
static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
static LLVMType getInt1Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/1);
static LLVMType getIntNTy(MLIRContext *context, unsigned numBits);
static LLVMType getInt1Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/1);
}
static LLVMType getInt8Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/8);
static LLVMType getInt8Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/8);
}
static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
return getInt8Ty(dialect).getPointerTo();
static LLVMType getInt8PtrTy(MLIRContext *context) {
return getInt8Ty(context).getPointerTo();
}
static LLVMType getInt16Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/16);
static LLVMType getInt16Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/16);
}
static LLVMType getInt32Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/32);
static LLVMType getInt32Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/32);
}
static LLVMType getInt64Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/64);
static LLVMType getInt64Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/64);
}
/// Utilities used to generate other miscellaneous types.
@ -187,33 +187,33 @@ public:
static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
return getFunctionTy(result, llvm::None, isVarArg);
}
static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
static LLVMType getStructTy(MLIRContext *context, ArrayRef<LLVMType> elements,
bool isPacked = false);
static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
return getStructTy(dialect, llvm::None, isPacked);
static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) {
return getStructTy(context, llvm::None, isPacked);
}
template <typename... Args>
static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>::type
getStructTy(LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
return getStructTy(&elt1.getDialect(), fields);
return getStructTy(elt1.getContext(), fields);
}
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
/// Void type utilities.
static LLVMType getVoidTy(LLVMDialect *dialect);
static LLVMType getVoidTy(MLIRContext *context);
bool isVoidTy();
// Creation and setting of LLVM's identified struct types
static LLVMType createStructTy(LLVMDialect *dialect,
static LLVMType createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name,
bool isPacked = false);
static LLVMType createStructTy(LLVMDialect *dialect,
static LLVMType createStructTy(MLIRContext *context,
Optional<StringRef> name) {
return createStructTy(dialect, llvm::None, name);
return createStructTy(context, llvm::None, name);
}
static LLVMType createStructTy(ArrayRef<LLVMType> elements,
@ -222,7 +222,7 @@ public:
assert(!elements.empty() &&
"This method may not be invoked with an empty list");
LLVMType ele0 = elements.front();
return createStructTy(&ele0.getDialect(), elements, name, isPacked);
return createStructTy(ele0.getContext(), elements, name, isPacked);
}
template <typename... Args>
@ -231,7 +231,7 @@ public:
createStructTy(StringRef name, LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
Optional<StringRef> opt_name(name);
return createStructTy(&elt1.getDialect(), fields, opt_name);
return createStructTy(elt1.getContext(), fields, opt_name);
}
static LLVMType setStructTyBody(LLVMType structType,

View File

@ -67,14 +67,14 @@ private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
void initializeCachedTypes() {
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
llvmPointerPointerType = llvmPointerType.getPointerTo();
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmInt8Type = LLVM::LLVMType::getInt8Ty(&getContext());
llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
&getContext(), llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
@ -91,7 +91,7 @@ private:
LLVM::LLVMType getIntPtrType() {
return LLVM::LLVMType::getIntNTy(
getLLVMDialect(),
&getContext(),
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
@ -340,7 +340,7 @@ Value GpuLaunchFuncToGpuRuntimeCallsPass::generateKernelNameConstant(
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
return LLVM::createGlobalString(
loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
LLVM::Linkage::Internal, llvmDialect);
LLVM::Linkage::Internal);
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
@ -378,9 +378,9 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
SmallString<128> nameBuffer(kernelModule.getName());
nameBuffer.append(kGpuBinaryStorageSuffix);
Value data = LLVM::createGlobalString(
loc, builder, nameBuffer.str(), binaryAttr.getValue(),
LLVM::Linkage::Internal, getLLVMDialect());
Value data =
LLVM::createGlobalString(loc, builder, nameBuffer.str(),
binaryAttr.getValue(), LLVM::Linkage::Internal);
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.

View File

@ -89,7 +89,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value zero = nullptr;
if (!workgroupBuffers.empty())
@ -117,7 +117,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
auto type = attribution.getType().cast<MemRefType>();

View File

@ -46,17 +46,17 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto dialect = typeConverter.getDialect();
MLIRContext *context = rewriter.getContext();
Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
case X:
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Y:
newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Z:
newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
default:
return failure();
@ -64,10 +64,10 @@ public:
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});

View File

@ -85,7 +85,7 @@ private:
return operand;
return rewriter.create<LLVM::FPExtOp>(
operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
operand);
}

View File

@ -57,11 +57,11 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);
auto dialect = typeConverter.getDialect();
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
auto resultTy =
LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));

View File

@ -57,15 +57,12 @@ class VulkanLaunchFuncToVulkanCallsPass
: public ConvertVulkanLaunchFuncToVulkanCallsBase<
VulkanLaunchFuncToVulkanCallsPass> {
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
void initializeCachedTypes() {
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
}
LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@ -87,7 +84,7 @@ private:
// `!llvm<"{ `element-type`*, `element-type`*, i64,
// [`rank` x i64], [`rank` x i64]}">`.
return LLVM::LLVMType::getStructTy(
llvmDialect,
&getContext(),
{llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}
@ -153,7 +150,6 @@ public:
void runOnOperation() override;
private:
LLVM::LLVMDialect *llvmDialect;
LLVM::LLVMType llvmFloatType;
LLVM::LLVMType llvmVoidType;
LLVM::LLVMType llvmPointerType;
@ -245,7 +241,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
// int16_t and bitcast the descriptor.
if (type.isHalfTy()) {
auto memRefTy =
getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
}
@ -324,15 +320,15 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
}
for (unsigned i = 1; i <= 3; i++) {
for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
LLVM::LLVMType::getInt32Ty(llvmDialect),
LLVM::LLVMType::getInt16Ty(llvmDialect),
LLVM::LLVMType::getInt8Ty(llvmDialect),
LLVM::LLVMType::getHalfTy(llvmDialect)}) {
for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
LLVM::LLVMType::getInt32Ty(&getContext()),
LLVM::LLVMType::getInt16Ty(&getContext()),
LLVM::LLVMType::getInt8Ty(&getContext()),
LLVM::LLVMType::getHalfTy(&getContext())}) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
if (type.isHalfTy())
type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext()));
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),
@ -368,8 +364,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
shaderName, LLVM::Linkage::Internal,
getLLVMDialect());
shaderName, LLVM::Linkage::Internal);
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@ -388,7 +383,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
LLVM::Linkage::Internal, getLLVMDialect());
LLVM::Linkage::Internal);
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(

View File

@ -186,15 +186,15 @@ static Type convertStructTypePacked(spirv::StructType type,
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector,
return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
/*isPacked=*/true);
}
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
LLVMTypeConverter &converter, unsigned value) {
unsigned value) {
return rewriter.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()),
loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
@ -1002,7 +1002,7 @@ public:
return failure();
Location loc = varOp.getLoc();
Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1);
Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
return success();

View File

@ -199,7 +199,7 @@ llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
@ -211,19 +211,19 @@ Type LLVMTypeConverter::convertIndexType(IndexType type) {
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return LLVM::LLVMType::getFloatTy(llvmDialect);
return LLVM::LLVMType::getFloatTy(&getContext());
case mlir::StandardTypes::F64:
return LLVM::LLVMType::getDoubleTy(llvmDialect);
return LLVM::LLVMType::getDoubleTy(&getContext());
case mlir::StandardTypes::F16:
return LLVM::LLVMType::getHalfTy(llvmDialect);
return LLVM::LLVMType::getHalfTy(&getContext());
case mlir::StandardTypes::BF16: {
return LLVM::LLVMType::getBFloatTy(llvmDialect);
return LLVM::LLVMType::getBFloatTy(&getContext());
}
default:
llvm_unreachable("non-float type in convertFloatType");
@ -238,7 +238,7 @@ static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType});
return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
}
// Except for signatures, MLIR function types are converted into LLVM
@ -274,7 +274,7 @@ LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
/// pointer to descriptor".
SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
}
// Function types are converted to LLVM Function types by recursively converting
@ -307,7 +307,7 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// a struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
@ -331,7 +331,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
@ -400,7 +400,7 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
auto rankTy = getIndexType();
auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}
@ -853,11 +853,11 @@ LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
return LLVM::LLVMType::getVoidTy(&getDialect());
return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&getDialect());
return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
}
Value ConvertToLLVMPattern::createIndexConstant(
@ -2025,9 +2025,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
unrankedMemrefs, sizes);
// Get frequently used types.
auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect());
auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect());
MLIRContext *context = builder.getContext();
auto voidType = LLVM::LLVMType::getVoidTy(context);
auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
auto i1Type = LLVM::LLVMType::getInt1Ty(context);
LLVM::LLVMType indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
@ -3168,7 +3169,7 @@ struct GenericAtomicRMWOpLowering
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
@ -3330,13 +3331,13 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
resultTypes.push_back(converted);
}
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.

View File

@ -715,7 +715,7 @@ public:
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@ -832,7 +832,7 @@ public:
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter.convertType(oneDVectorType), extracted,
@ -1074,7 +1074,7 @@ public:
if (failed(successStrides) || !isContiguous)
return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@ -1263,11 +1263,10 @@ private:
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
if (value.getType() ==
LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
// Convert i1 (bool) to i32 so we can use the print_i32 method.
// This avoids the need for a print_i1 method with an unclear ABI.
auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto trueVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(1));
auto falseVal = rewriter.create<ConstantOp>(
@ -1303,8 +1302,8 @@ private:
}
// Helper for printer method declaration (first hit) and lookup.
static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
StringRef name, ArrayRef<LLVM::LLVMType> params) {
static Operation *getPrint(Operation *op, StringRef name,
ArrayRef<LLVM::LLVMType> params) {
auto module = op->getParentOfType<ModuleOp>();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
@ -1312,42 +1311,39 @@ private:
OpBuilder moduleBuilder(module.getBodyRegion());
return moduleBuilder.create<LLVM::LLVMFuncOp>(
op->getLoc(), name,
LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
params, /*isVarArg=*/false));
LLVM::LLVMType::getFunctionTy(
LLVM::LLVMType::getVoidTy(op->getContext()), params,
/*isVarArg=*/false));
}
// Helpers for method names.
Operation *getPrintI32(Operation *op) const {
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i32",
LLVM::LLVMType::getInt32Ty(dialect));
return getPrint(op, "print_i32",
LLVM::LLVMType::getInt32Ty(op->getContext()));
}
Operation *getPrintI64(Operation *op) const {
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i64",
LLVM::LLVMType::getInt64Ty(dialect));
return getPrint(op, "print_i64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintFloat(Operation *op) const {
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f32",
LLVM::LLVMType::getFloatTy(dialect));
return getPrint(op, "print_f32",
LLVM::LLVMType::getFloatTy(op->getContext()));
}
Operation *getPrintDouble(Operation *op) const {
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f64",
LLVM::LLVMType::getDoubleTy(dialect));
return getPrint(op, "print_f64",
LLVM::LLVMType::getDoubleTy(op->getContext()));
}
Operation *getPrintOpen(Operation *op) const {
return getPrint(op, typeConverter.getDialect(), "print_open", {});
return getPrint(op, "print_open", {});
}
Operation *getPrintClose(Operation *op) const {
return getPrint(op, typeConverter.getDialect(), "print_close", {});
return getPrint(op, "print_close", {});
}
Operation *getPrintComma(Operation *op) const {
return getPrint(op, typeConverter.getDialect(), "print_comma", {});
return getPrint(op, "print_comma", {});
}
Operation *getPrintNewline(Operation *op) const {
return getPrint(op, typeConverter.getDialect(), "print_newline", {});
return getPrint(op, "print_newline", {});
}
};

View File

@ -101,8 +101,7 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
auto resultType = LLVMType::getInt1Ty(dialect);
auto resultType = LLVMType::getInt1Ty(builder.getContext());
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
@ -393,11 +392,9 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@ -601,11 +598,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
"expected function with 0 or 1 result");
Builder &builder = parser.getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@ -1101,9 +1096,8 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
if (types.empty()) {
if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
MLIRContext *context = parser.getBuilder().getContext();
auto *dialect = context->getRegisteredDialect<LLVMDialect>();
auto arrayType = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
types.push_back(arrayType);
} else {
return parser.emitError(parser.getNameLoc(),
@ -1265,14 +1259,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
llvmInputs.push_back(llvmTy);
}
// Get the dialect from the input type, if any exist. Look it up in the
// context otherwise.
LLVMDialect *dialect =
llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
: &llvmInputs.front().getDialect();
// No output is denoted as "void" in LLVM type system.
LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
: outputs.front().dyn_cast<LLVMType>();
if (!llvmOutput) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
@ -1605,8 +1593,7 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
parser.resolveOperand(val, type, result.operands))
return failure();
auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
auto boolType = LLVMType::getInt1Ty(dialect);
auto boolType = LLVMType::getInt1Ty(builder.getContext());
auto resultType = LLVMType::getStructTy(type, boolType);
result.addTypes(resultType);
@ -1777,8 +1764,7 @@ LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
LLVM::Linkage linkage,
LLVM::LLVMDialect *llvmDialect) {
LLVM::Linkage linkage) {
assert(builder.getInsertionBlock() &&
builder.getInsertionBlock()->getParentOp() &&
"expected builder to point to a block constrained in an op");
@ -1788,8 +1774,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Create the global at the entry of the module.
OpBuilder moduleBuilder(module.getBodyRegion());
auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
value.size());
MLIRContext *ctx = builder.getContext();
auto type =
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
auto global = moduleBuilder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, linkage, name,
builder.getStringAttr(value));
@ -1797,10 +1784,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
loc, LLVM::LLVMType::getInt64Ty(ctx),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(loc,
LLVM::LLVMType::getInt8PtrTy(llvmDialect),
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
globalPtr, ArrayRef<Value>({cst0, cst0}));
}

View File

@ -127,35 +127,35 @@ bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
//----------------------------------------------------------------------------//
// Utilities used to generate floating point types.
LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
return LLVMDoubleType::get(dialect->getContext());
LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
return LLVMDoubleType::get(context);
}
LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
return LLVMFloatType::get(dialect->getContext());
LLVMType LLVMType::getFloatTy(MLIRContext *context) {
return LLVMFloatType::get(context);
}
LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
return LLVMBFloatType::get(dialect->getContext());
LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
return LLVMBFloatType::get(context);
}
LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
return LLVMHalfType::get(dialect->getContext());
LLVMType LLVMType::getHalfTy(MLIRContext *context) {
return LLVMHalfType::get(context);
}
LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) {
return LLVMFP128Type::get(dialect->getContext());
LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
return LLVMFP128Type::get(context);
}
LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) {
return LLVMX86FP80Type::get(dialect->getContext());
LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
return LLVMX86FP80Type::get(context);
}
//----------------------------------------------------------------------------//
// Utilities used to generate integer types.
LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
return LLVMIntegerType::get(dialect->getContext(), numBits);
LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
return LLVMIntegerType::get(context, numBits);
}
//----------------------------------------------------------------------------//
@ -170,9 +170,9 @@ LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
return LLVMFunctionType::get(result, params, isVarArg);
}
LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
LLVMType LLVMType::getStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements, bool isPacked) {
return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
return LLVMStructType::getLiteral(context, elements, isPacked);
}
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
@ -182,8 +182,8 @@ LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
//----------------------------------------------------------------------------//
// Void type utilities.
LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
return LLVMVoidType::get(dialect->getContext());
LLVMType LLVMType::getVoidTy(MLIRContext *context) {
return LLVMVoidType::get(context);
}
bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
@ -191,7 +191,7 @@ bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
//----------------------------------------------------------------------------//
// Creation and setting of LLVM's identified struct types
LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
LLVMType LLVMType::createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name, bool isPacked) {
assert(name.hasValue() &&
@ -200,8 +200,7 @@ LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
std::string stringName = stringNameBase.str();
unsigned counter = 0;
do {
auto type =
LLVMStructType::getIdentified(dialect->getContext(), stringName);
auto type = LLVMStructType::getIdentified(context, stringName);
if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
counter += 1;
stringName =

View File

@ -41,12 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
return parser.getBuilder()
.getContext()
->getRegisteredDialect<LLVM::LLVMDialect>();
}
// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// ({return_value_and_is_valid})? : result_type
@ -69,7 +63,7 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
break;
}
auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
@ -77,9 +71,9 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
auto llvmDialect = getLlvmDialect(parser);
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
MLIRContext *context = parser.getBuilder().getContext();
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
@ -92,14 +86,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
}
static LogicalResult verify(MmaOp op) {
auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
MLIRContext *context = op.getContext();
auto f16Ty = LLVM::LLVMType::getHalfTy(context);
auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
auto f32Ty = LLVM::LLVMType::getFloatTy(context);
auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f32x8StructTy = LLVM::LLVMType::getStructTy(
dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
op.getOperandTypes().end());

View File

@ -34,12 +34,6 @@ using namespace ROCDL;
// Parsing for ROCDL ops
//===----------------------------------------------------------------------===//
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
return parser.getBuilder()
.getContext()
->getRegisteredDialect<LLVM::LLVMDialect>();
}
// <operation> ::=
// `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
// result_type`
@ -51,8 +45,9 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
parser.addTypeToList(type, result.types))
return failure();
auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
MLIRContext *context = parser.getBuilder().getContext();
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
return parser.resolveOperands(ops,
{i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
@ -69,8 +64,9 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
return failure();
auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
MLIRContext *context = parser.getBuilder().getContext();
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
if (parser.resolveOperands(ops,