Use llvm.func to define functions with wrapped LLVM IR function type

This function-like operation allows one to define functions that have a
wrapped LLVM IR function type, in particular variadic functions. The operation
was added in parallel to the existing lowering flow; this commit only switches
the flow to use it.

Using a custom function type makes the LLVM IR dialect type system more
consistent and avoids complex conversion rules: previously, functions had to
use the built-in function type instead of a wrapped LLVM IR dialect type and
perform conversions during the analysis.
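As a sketch of what this change enables (the `printf` declaration below is illustrative; the variadic syntax mirrors the tests updated in this commit), compare the old attribute-based workaround with the new operation:

```mlir {.mlir}
// Before: built-in function type; variadic-ness only recorded as an attribute.
func @printf(!llvm<"i8*">) -> !llvm.i32 attributes {std.varargs = true}

// After: wrapped LLVM IR function type; the trailing `...` is first-class.
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
```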

PiperOrigin-RevId: 273910855
Author: Alex Zinenko, 2019-10-10 01:33:33 -07:00 (committed by A. Unique TensorFlower)
Parent: 309b4556d0
Commit: 5e7959a353
29 changed files with 324 additions and 307 deletions


@ -152,8 +152,6 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(applyFullConversion(module, target, patterns, &converter)))
return failure();


@ -138,14 +138,14 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction; we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
Type retTy = printfFunc.getType().getResult(0);
Type retTy = printfFunc.getType().getFunctionResultType();
// Create our loop nest now
using namespace edsc;
@ -218,24 +218,23 @@ private:
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
FuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
if (printfFunc)
return printfFunc;
// Create a function declaration for printf; the signature is `i32 (i8*, ...)`
Builder builder(module);
OpBuilder builder(module.getBodyRegion());
auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
printfFunc = FuncOp::create(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
module.push_back(printfFunc);
return printfFunc;
auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
printfTy,
ArrayRef<NamedAttribute>());
}
};
@ -369,10 +368,10 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
ConversionTarget target(getContext());
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
LLVM::LLVMDialect, StandardOpsDialect>();
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType());
});
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
&typeConverter))) {
emitError(UnknownLoc::get(getModule().getContext()),


@ -137,49 +137,27 @@ Examples:
### Function Signature Conversion
The MLIR function type is built into the representation; even functions in
dialects that include a first-class function type must use the built-in MLIR
function type. During the conversion to LLVM IR, function signatures are
converted as follows:
- the outer type remains the built-in MLIR function;
- function arguments are converted individually following these rules;
- function results:
- zero-result functions remain zero-result;
- single-result functions have their result type converted according to
these rules;
- multi-result functions have a single result type of the wrapped LLVM IR
structure type with elements corresponding to the converted original
results.
Rationale: function definitions remain analyzable within MLIR without having to
abstract away the function type. In order to remain consistent with the regular
MLIR functions, we do not introduce a `void` result type since we cannot create
a value of `void` type that MLIR passes might expect to be returned from a
function.
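For example, a caller of a multi-result function receives the packed structure and extracts the individual values (a sketch based on the conversion rules above; `@pair` is a hypothetical two-result function used only for illustration):

```mlir {.mlir}
// func @pair() -> (i64, f64) is converted to return !llvm<"{i64, double}">.
%0 = llvm.call @pair() : () -> !llvm<"{i64, double}">
%1 = llvm.extractvalue %0[0] : !llvm<"{i64, double}">
%2 = llvm.extractvalue %0[1] : !llvm<"{i64, double}">
```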
LLVM IR functions are defined by a custom operation. The function itself has a
wrapped LLVM IR function type converted as described above. The function
definition operation uses MLIR syntax.
Examples:
```mlir {.mlir}
// zero-ary function type with no results.
func @foo() -> ()
// remains as is
func @foo() -> ()
// gets LLVM type void().
llvm.func @foo() -> ()
// unary function with one result
// function with one result
func @bar(i32) -> (i64)
// has its argument and result type converted
func @bar(!llvm.type<"i32">) -> !llvm.type<"i64">
// gets converted to LLVM type i64(i32).
func @bar(!llvm.i32) -> !llvm.i64
// binary function with one result
func @baz(i32, f32) -> (i64)
// has its arguments handled separately
func @baz(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"i64">
// binary function with two results
// function with two results
func @qux(i32, f32) -> (i64, f64)
// has its result aggregated into a structure type
func @qux(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"{i64, double}">
func @qux(!llvm.i32, !llvm.float) -> !llvm.type<"{i64, double}">
// function-typed arguments or results in higher-order functions
func @quux(() -> ()) -> (() -> ())


@ -50,6 +50,30 @@ specific LLVM IR type.
All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
of an operation is that used in LLVM IR prefixed with "`llvm.`".
### LLVM functions
MLIR functions are defined by an operation that is not built into the IR itself.
The LLVM IR dialect provides an `llvm.func` operation to define functions
compatible with LLVM IR. These functions have a wrapped LLVM IR function type
but use MLIR syntax to express it. They are required to have exactly one result
type. The LLVM function operation is intended to capture additional properties
of LLVM functions, such as linkage and calling convention, that may be modeled
differently by the built-in MLIR function.
```mlir {.mlir}
// The type of @bar is !llvm<"i64 (i64)">
llvm.func @bar(%arg0: !llvm.i64) -> !llvm.i64 {
llvm.return %arg0 : !llvm.i64
}
// The type of @foo is !llvm<"void (i64)">
// !llvm.void type is omitted
llvm.func @foo(%arg0: !llvm.i64) {
llvm.return
}
```
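Function declarations omit the body, and variadic signatures are expressed with a trailing `...` (a sketch consistent with the tests in this change):

```mlir {.mlir}
// A declaration only: the type of @malloc is !llvm<"i8* (i64)">.
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">

// A variadic declaration: the type of @printf is !llvm<"i32 (i8*, ...)">.
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
```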
### LLVM IR operations
The following operations are currently supported. The semantics of these


@ -25,15 +25,12 @@
namespace mlir {
class FuncOp;
class Location;
class ModuleOp;
class OpBuilder;
class Value;
namespace LLVM {
class LLVMDialect;
}
} // namespace LLVM
template <typename T> class OpPassBase;


@ -50,6 +50,12 @@ public:
/// non-standard or non-builtin types.
Type convertType(Type t) override;
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond


@ -55,7 +55,7 @@ public:
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(FuncOp function);
static bool isKernel(Operation *op);
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attr) override;


@ -64,6 +64,9 @@ public:
LLVMDialect &getDialect();
llvm::Type *getUnderlyingType() const;
/// Utilities to identify types.
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
/// Array type utilities.
LLVMType getArrayElementType();
unsigned getArrayNumElements();


@ -525,11 +525,15 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
let builders = [
OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
"LLVMType type, ArrayRef<NamedAttribute> attrs, "
"LLVMType type, ArrayRef<NamedAttribute> attrs = {}, "
"ArrayRef<NamedAttributeList> argAttrs = {}">
];
let extraClassDeclaration = [{
// Add an entry block to an empty function, and set up the block arguments
// to match the signature of the function.
Block *addEntryBlock();
LLVMType getType() {
return getAttrOfType<TypeAttr>(getTypeAttrName())
.getValue().cast<LLVMType>();


@ -34,13 +34,14 @@
namespace mlir {
class Attribute;
class FuncOp;
class Location;
class ModuleOp;
class Operation;
namespace LLVM {
class LLVMFuncOp;
// Implementation class for module translation. Holds a reference to the module
// being translated, and the mappings between the original and the translated
// functions, basic blocks and values. It is practically easier to hold these
@ -75,8 +76,8 @@ protected:
private:
LogicalResult convertFunctions();
void convertGlobals();
LogicalResult convertOneFunction(FuncOp func);
void connectPHINodes(FuncOp func);
LogicalResult convertOneFunction(LLVMFuncOp func);
void connectPHINodes(LLVMFuncOp func);
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
template <typename Range>


@ -80,6 +80,7 @@ private:
void initializeCachedTypes() {
const llvm::Module &module = llvmDialect->getLLVMModule();
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmPointerPointerType = llvmPointerType.getPointerTo();
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
@ -89,6 +90,8 @@ private:
llvmDialect, module.getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
LLVM::LLVMType getPointerType() { return llvmPointerType; }
LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
@ -120,7 +123,7 @@ private:
void declareCudaFunctions(Location loc);
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
Value *generateKernelNameConstant(StringRef name, Location &loc,
Value *generateKernelNameConstant(StringRef name, Location loc,
OpBuilder &builder);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
@ -145,6 +148,7 @@ public:
private:
LLVM::LLVMDialect *llvmDialect;
LLVM::LLVMType llvmVoidType;
LLVM::LLVMType llvmPointerType;
LLVM::LLVMType llvmPointerPointerType;
LLVM::LLVMType llvmInt8Type;
@ -160,38 +164,41 @@ private:
// uses void pointers. This is fine as they have the same linkage in C.
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
ModuleOp module = getModule();
Builder builder(module);
if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
module.push_back(
FuncOp::create(loc, cuModuleLoadName,
builder.getFunctionType(
{
getPointerPointerType(), /* CUmodule *module */
getPointerType() /* void *cubin */
},
getCUResultType())));
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(cuModuleLoadName)) {
builder.create<LLVM::LLVMFuncOp>(
loc, cuModuleLoadName,
LLVM::LLVMType::getFunctionTy(
getCUResultType(),
{
getPointerPointerType(), /* CUmodule *module */
getPointerType() /* void *cubin */
},
/*isVarArg=*/false));
}
if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
if (!module.lookupSymbol(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction.
module.push_back(
FuncOp::create(loc, cuModuleGetFunctionName,
builder.getFunctionType(
{
getPointerPointerType(), /* void **function */
getPointerType(), /* void *module */
getPointerType() /* char *name */
},
getCUResultType())));
builder.create<LLVM::LLVMFuncOp>(
loc, cuModuleGetFunctionName,
LLVM::LLVMType::getFunctionTy(
getCUResultType(),
{
getPointerPointerType(), /* void **function */
getPointerType(), /* void *module */
getPointerType() /* char *name */
},
/*isVarArg=*/false));
}
if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
if (!module.lookupSymbol(cuLaunchKernelName)) {
// Unlike the CUDA API, the wrappers use uintptr_t to match the
// LLVM type of MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
// CUstream.
module.push_back(FuncOp::create(
builder.create<LLVM::LLVMFuncOp>(
loc, cuLaunchKernelName,
builder.getFunctionType(
LLVM::LLVMType::getFunctionTy(
getCUResultType(),
{
getPointerType(), /* void* f */
getIntPtrType(), /* intptr_t gridXDim */
@ -205,32 +212,31 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
getPointerPointerType(), /* void **kernelParams */
getPointerPointerType() /* void **extra */
},
getCUResultType())));
/*isVarArg=*/false));
}
if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
if (!module.lookupSymbol(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of
// CUDA's opaque CUstream.
module.push_back(FuncOp::create(
builder.create<LLVM::LLVMFuncOp>(
loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */)));
LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
}
if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
module.push_back(
FuncOp::create(loc, cuStreamSynchronizeName,
builder.getFunctionType(
{
getPointerType() /* CUstream stream */
},
getCUResultType())));
if (!module.lookupSymbol(cuStreamSynchronizeName)) {
builder.create<LLVM::LLVMFuncOp>(
loc, cuStreamSynchronizeName,
LLVM::LLVMType::getFunctionTy(getCUResultType(),
getPointerType() /* CUstream stream */,
/*isVarArg=*/false));
}
if (!module.lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr)) {
module.push_back(FuncOp::create(loc, kMcuMemHostRegisterPtr,
builder.getFunctionType(
{
getPointerType(), /* void *ptr */
getInt32Type() /* int32 flags*/
},
{})));
if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kMcuMemHostRegisterPtr,
LLVM::LLVMType::getFunctionTy(getVoidType(),
{
getPointerType(), /* void *ptr */
getInt32Type() /* int32 flags*/
},
/*isVarArg=*/false));
}
}
@ -271,7 +277,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
if (llvmType.isStructTy()) {
auto registerFunc =
getModule().lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr);
getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr);
auto zero = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(0));
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
@ -304,7 +310,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
StringRef name, Location &loc, OpBuilder &builder) {
StringRef name, Location loc, OpBuilder &builder) {
// Make sure the trailing zero is included in the constant.
std::vector<char> kernelName(name.begin(), name.end());
kernelName.push_back('\0');
@ -355,6 +361,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
<< "missing " << kCubinAnnotation << " attribute";
return signalPassFailure();
}
assert(kernelModule.getName() && "expected a named module");
SmallString<128> nameBuffer(*kernelModule.getName());
nameBuffer.append(kCubinStorageSuffix);
@ -364,7 +371,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
auto cuModuleLoad =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data});
@ -374,20 +382,21 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
auto cuFunction = allocatePointer(builder, loc);
FuncOp cuModuleGetFunction =
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
auto cuModuleGetFunction =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
FuncOp cuGetStreamHelper =
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
auto cuGetStreamHelper =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
// Invoke the function with required arguments.
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
auto cuLaunchKernel =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
@ -404,7 +413,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
paramsArray, /* kernel params */
nullpointer /* extra */});
// Sync on the stream to make it synchronous.
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
auto cuStreamSync =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuStreamSync),
ArrayRef<Value *>(cuStream.getResult(0)));


@ -381,8 +381,6 @@ public:
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<NVVM::NVVMDialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
}


@ -95,19 +95,31 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
}
}
// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
LLVM::LLVMType converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
return converted.getPointerTo();
}
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If the MLIR Function has zero results, the LLVM
// Function has one VoidType result. If the MLIR Function has more than one
// result, they are packed into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
FunctionType type, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Convert argument types one by one and check for errors.
SmallVector<LLVM::LLVMType, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
if (!converted)
for (auto &en : llvm::enumerate(type.getInputs()))
if (failed(convertSignatureArg(en.index(), en.value(), result)))
return {};
argTypes.push_back(unwrap(converted));
}
SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
for (Type type : result.getConvertedTypes())
argTypes.push_back(unwrap(type));
// If the function does not return anything, create the void result type;
// if it returns one element, convert it; otherwise pack the result types into
@ -118,8 +130,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
.getPointerTo();
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
@ -249,6 +260,10 @@ public:
&dialect, getModule().getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() const {
return LLVM::LLVMType::getVoidTy(&dialect);
}
// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&dialect);
@ -289,7 +304,16 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
SmallVector<Type, 4> argTypes;
// Pack the result types into a struct.
Type packedResult;
if (type.getNumResults() != 0)
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
return matchFailure();
LLVM::LLVMType resultType = packedResult
? packedResult.cast<LLVM::LLVMType>()
: LLVM::LLVMType::getVoidTy(&dialect);
SmallVector<LLVM::LLVMType, 4> argTypes;
argTypes.reserve(type.getNumInputs());
SmallVector<unsigned, 4> promotedArgIndices;
promotedArgIndices.reserve(type.getNumInputs());
@ -297,14 +321,15 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
// Convert the original function arguments. Struct arguments are promoted to
// pointer to struct arguments to allow calling external functions with
// various ABIs (e.g. compiled from C/C++ on platform X).
TypeConverter::SignatureConversion result(type.getNumInputs());
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
for (auto en : llvm::enumerate(type.getInputs())) {
auto t = en.value();
auto converted = lowering.convertType(t);
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return matchFailure();
if (t.isa<MemRefType>()) {
converted = converted.cast<LLVM::LLVMType>().getPointerTo();
converted = converted.getPointerTo();
promotedArgIndices.push_back(en.index());
}
argTypes.push_back(converted);
@ -312,21 +337,24 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
result.addInputs(idx, argTypes[idx]);
// Pack the result types into a struct.
Type packedResult;
if (type.getNumResults() != 0) {
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
return matchFailure();
auto llvmType = LLVM::LLVMType::getFunctionTy(
resultType, argTypes, varargsAttr && varargsAttr.getValue());
// Only retain those attributes that are not constructed by build.
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : funcOp.getAttrs()) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is("std.varargs"))
continue;
attributes.push_back(attr);
}
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
// Create an LLVM function.
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
op->getLoc(), funcOp.getName(), llvmType, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(FunctionType::get(
result.getConvertedTypes(),
packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
@ -627,13 +655,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc =
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
module.push_back(mallocFunc);
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
/*isVarArg=*/false));
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
@ -792,12 +820,14 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
FuncOp freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
auto freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
op->getParentOfType<ModuleOp>().push_back(freeFunc);
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
/*isVarArg=*/false));
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
@ -1373,9 +1403,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter->isSignatureLegal(op.getType());
});
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
signalPassFailure();
}


@ -6,5 +6,5 @@ add_llvm_library(MLIRGPU
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
)
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR LLVMSupport)
target_link_libraries(MLIRGPU MLIRIR MLIRStandardOps LLVMSupport)
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR MLIRLLVMIR LLVMSupport)
target_link_libraries(MLIRGPU MLIRIR MLIRLLVMIR MLIRStandardOps LLVMSupport)


@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
@ -37,9 +38,8 @@ using namespace mlir::gpu;
StringRef GPUDialect::getDialectName() { return "gpu"; }
bool GPUDialect::isKernel(FuncOp function) {
UnitAttr isKernelAttr =
function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
@ -92,18 +92,25 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Check that `launch_func` refers to a well-formed kernel function.
StringRef kernelName = launchOp.kernel();
auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(kernelName);
if (!kernelFunction)
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelStdFunction && !kernelLLVMFunction)
return launchOp.emitOpError("kernel function '")
<< kernelName << "' is undefined";
if (!kernelFunction.getAttrOfType<mlir::UnitAttr>(
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName()))
return launchOp.emitOpError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
if (launchOp.getNumKernelOperands() != kernelFunction.getNumArguments())
return launchOp.emitOpError("got ") << launchOp.getNumKernelOperands()
<< " kernel operands but expected "
<< kernelFunction.getNumArguments();
unsigned actualNumArguments = launchOp.getNumKernelOperands();
unsigned expectedNumArguments = kernelLLVMFunction
? kernelLLVMFunction.getNumArguments()
: kernelStdFunction.getNumArguments();
if (expectedNumArguments != actualNumArguments)
return launchOp.emitOpError("got ")
<< actualNumArguments << " kernel operands but expected "
<< expectedNumArguments;
// Due to the ordering of the current impl of lowering and LLVMLowering,
// type checks need to be temporarily disabled.


@ -178,15 +178,17 @@ public:
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
auto mallocFunc = module.lookupSymbol<LLVMFuncOp>("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc =
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
module.push_back(mallocFunc);
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty,
/*isVarArg=*/false));
}
// Get MLIR types for injecting element pointer.
@ -257,15 +259,18 @@ public:
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
auto freeFunc = module.lookupSymbol<LLVMFuncOp>("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
module.push_back(freeFunc);
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
freeFunc = moduleBuilder.create<LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy,
/*isVarArg=*/false));
}
// Emit MLIR for buffer_dealloc.


@ -177,7 +177,7 @@ compileAndExecute(ModuleOp module, StringRef entryPoint,
static Error compileAndExecuteVoidFunction(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.getBlocks().empty())
return make_string_error("entry point not found");
void *empty = nullptr;
@ -187,22 +187,14 @@ static Error compileAndExecuteVoidFunction(
static Error compileAndExecuteSingleFloatReturnFunction(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal()) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
}
if (!mainFunction.getType().getInputs().empty())
if (mainFunction.getType().getFunctionNumParams() != 0)
return make_string_error("function inputs not supported");
if (mainFunction.getType().getResults().size() != 1)
return make_string_error("only single f32 function result supported");
auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
if (!t)
return make_string_error("only single llvm.f32 function result supported");
auto *llvmTy = t.getUnderlyingType();
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
if (!mainFunction.getType().getFunctionResultType().isFloatTy())
return make_string_error("only single llvm.f32 function result supported");
float res;


@ -25,6 +25,7 @@ add_llvm_library(MLIRTargetNVVMIR
target_link_libraries(MLIRTargetNVVMIR
MLIRGPU
MLIRIR
MLIRLLVMIR
MLIRNVVMIR
MLIRTargetLLVMIRModuleTranslation
)
@ -39,6 +40,7 @@ add_llvm_library(MLIRTargetROCDLIR
target_link_libraries(MLIRTargetROCDLIR
MLIRGPU
MLIRIR
MLIRLLVMIR
MLIRROCDLIR
MLIRTargetLLVMIRModuleTranslation
)


@ -23,8 +23,8 @@
#include "mlir/Target/NVVMIR.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
@ -66,11 +66,13 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
ModuleTranslation translation(m);
auto llvmModule =
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
if (!llvmModule)
return llvmModule;
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
for (FuncOp func : m.getOps<FuncOp>()) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
if (!gpu::GPUDialect::isKernel(func))
continue;
auto *llvmFunc = llvmModule->getFunction(func.getName());


@ -23,6 +23,7 @@
#include "mlir/Target/ROCDLIR.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
@ -93,7 +94,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
// foreach GPU kernel
// 1. Insert AMDGPU_KERNEL calling convention.
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
for (FuncOp func : m.getOps<FuncOp>()) {
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;


@ -39,38 +39,6 @@
namespace mlir {
namespace LLVM {
// Convert an MLIR function type to LLVM IR. Arguments of the function must be
// of MLIR LLVM IR dialect types. Use `loc` as a location when reporting errors.
// Return nullptr on errors.
static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
FunctionType type, Location loc,
bool isVarArgs) {
assert(type && "expected non-null type");
if (type.getNumResults() > 1)
return emitError(loc, "LLVM functions can only have 0 or 1 result"),
nullptr;
SmallVector<llvm::Type *, 8> argTypes;
argTypes.reserve(type.getNumInputs());
for (auto t : type.getInputs()) {
auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
return emitError(loc, "non-LLVM function argument type"), nullptr;
argTypes.push_back(wrappedLLVMType.getUnderlyingType());
}
if (type.getNumResults() == 0)
return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
isVarArgs);
auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!wrappedResultType)
return emitError(loc, "non-LLVM function result"), nullptr;
return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
argTypes, isVarArgs);
}
// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
// This currently supports integer, floating point, splat and dense element
// attributes and combinations thereof. In case of error, report it to `loc`
@ -362,7 +330,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
: terminator.getSuccessorOperand(1, index);
}
void ModuleTranslation::connectPHINodes(FuncOp func) {
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
// Skip the first block; it cannot be branched to, and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@ -393,7 +361,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
}
// Sort function blocks topologically.
static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
// For each block that has not been visited yet (i.e. that has no
// predecessors), add it to the list and traverse its successors in DFS
// preorder.
@ -407,7 +375,7 @@ static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
return blocks;
}
LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Clear the block and value mappings; they are only relevant within one
// function.
blockMapping.clear();
@ -460,24 +428,17 @@ LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
LogicalResult ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
llvm::FunctionType *functionType =
convertFunctionType(llvmModule->getContext(), function.getType(),
function.getLoc(), isVarArgs);
if (!functionType)
return failure();
llvm::FunctionCallee llvmFuncCst =
llvmModule->getOrInsertFunction(function.getName(), functionType);
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(),
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
functionMapping[function.getName()] =
cast<llvm::Function>(llvmFuncCst.getCallee());
}
// Convert functions.
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
// Ignore external functions.
if (function.isExternal())
continue;


@ -10,7 +10,7 @@ module attributes {gpu.container_module} {
attributes { gpu.kernel }
}
func @foo() {
llvm.func @foo() {
%0 = "op"() : () -> (!llvm.float)
%1 = "op"() : () -> (!llvm<"float*">)
%cst = constant 8 : index
@ -29,7 +29,7 @@ module attributes {gpu.container_module} {
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernel_module }
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
return
llvm.return
}
}


@ -1,8 +1,9 @@
// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s
// CHECK: attributes {gpu.kernel_module, nvvm.cubin = "CUBIN"}
module @kernels attributes {gpu.kernel_module} {
func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
module @foo attributes {gpu.kernel_module} {
llvm.func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
// CHECK: attributes {gpu.kernel}
attributes { gpu.kernel } {
llvm.return
}
@ -10,15 +11,15 @@ module @kernels attributes {gpu.kernel_module} {
// -----
module attributes {gpu.kernel_module} {
module @bar attributes {gpu.kernel_module} {
// CHECK: func @kernel_a
func @kernel_a()
llvm.func @kernel_a()
attributes { gpu.kernel } {
llvm.return
}
// CHECK: func @kernel_b
func @kernel_b()
llvm.func @kernel_b()
attributes { gpu.kernel } {
llvm.return
}


@ -1,7 +1,7 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: @fmuladd_test
func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
llvm.func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.fmuladd.f32.f32.f32
"llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.fmuladd.v8f32.v8f32.v8f32
@ -10,7 +10,7 @@ func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x fl
}
// CHECK-LABEL: @exp_test
func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
llvm.func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.exp.f32
"llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.exp.v8f32


@ -29,7 +29,7 @@ llvm.mlir.global @int_global_undef() : !llvm.i64
//
// CHECK: declare i8* @malloc(i64)
func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
// CHECK: declare void @free(i8*)
@ -41,12 +41,12 @@ func @malloc(!llvm.i64) -> !llvm<"i8*">
// CHECK-LABEL: define void @empty() {
// CHECK-NEXT: ret void
// CHECK-NEXT: }
func @empty() {
llvm.func @empty() {
llvm.return
}
// CHECK-LABEL: @global_refs
func @global_refs() {
llvm.func @global_refs() {
// Check load from globals.
// CHECK: load i32, i32* @i32_global
%0 = llvm.mlir.addressof @i32_global : !llvm<"i32*">
@ -63,11 +63,11 @@ func @global_refs() {
}
// CHECK-LABEL: declare void @body(i64)
func @body(!llvm.i64)
llvm.func @body(!llvm.i64)
// CHECK-LABEL: define void @simple_loop() {
func @simple_loop() {
llvm.func @simple_loop() {
// CHECK: br label %[[SIMPLE_bb1:[0-9]+]]
llvm.br ^bb1
@ -107,7 +107,7 @@ func @simple_loop() {
// CHECK-NEXT: call void @simple_loop()
// CHECK-NEXT: ret void
// CHECK-NEXT: }
func @simple_caller() {
llvm.func @simple_caller() {
llvm.call @simple_loop() : () -> ()
llvm.return
}
@ -124,20 +124,20 @@ func @simple_caller() {
// CHECK-NEXT: call void @more_imperfectly_nested_loops()
// CHECK-NEXT: ret void
// CHECK-NEXT: }
func @ml_caller() {
llvm.func @ml_caller() {
llvm.call @simple_loop() : () -> ()
llvm.call @more_imperfectly_nested_loops() : () -> ()
llvm.return
}
// CHECK-LABEL: declare i64 @body_args(i64)
func @body_args(!llvm.i64) -> !llvm.i64
llvm.func @body_args(!llvm.i64) -> !llvm.i64
// CHECK-LABEL: declare i32 @other(i64, i32)
func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
llvm.func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
// CHECK-LABEL: define i32 @func_args(i32 {{%.*}}, i32 {{%.*}}) {
// CHECK-NEXT: br label %[[ARGS_bb1:[0-9]+]]
func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
llvm.func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
%0 = llvm.mlir.constant(0 : i32) : !llvm.i32
llvm.br ^bb1
@ -182,17 +182,17 @@ func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
}
// CHECK: declare void @pre(i64)
func @pre(!llvm.i64)
llvm.func @pre(!llvm.i64)
// CHECK: declare void @body2(i64, i64)
func @body2(!llvm.i64, !llvm.i64)
llvm.func @body2(!llvm.i64, !llvm.i64)
// CHECK: declare void @post(i64)
func @post(!llvm.i64)
llvm.func @post(!llvm.i64)
// CHECK-LABEL: define void @imperfectly_nested_loops() {
// CHECK-NEXT: br label %[[IMPER_bb1:[0-9]+]]
func @imperfectly_nested_loops() {
llvm.func @imperfectly_nested_loops() {
llvm.br ^bb1
// CHECK: [[IMPER_bb1]]:
@ -259,10 +259,10 @@ func @imperfectly_nested_loops() {
}
// CHECK: declare void @mid(i64)
func @mid(!llvm.i64)
llvm.func @mid(!llvm.i64)
// CHECK: declare void @body3(i64, i64)
func @body3(!llvm.i64, !llvm.i64)
llvm.func @body3(!llvm.i64, !llvm.i64)
// A complete function transformation check.
// CHECK-LABEL: define void @more_imperfectly_nested_loops() {
@ -306,7 +306,7 @@ func @body3(!llvm.i64, !llvm.i64)
// CHECK: 21: ; preds = %2
// CHECK-NEXT: ret void
// CHECK-NEXT: }
func @more_imperfectly_nested_loops() {
llvm.func @more_imperfectly_nested_loops() {
llvm.br ^bb1
^bb1: // pred: ^bb0
%0 = llvm.mlir.constant(0 : index) : !llvm.i64
@ -359,7 +359,7 @@ func @more_imperfectly_nested_loops() {
//
// CHECK-LABEL: define void @memref_alloc()
func @memref_alloc() {
llvm.func @memref_alloc() {
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 400)
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
// CHECK-NEXT: %{{[0-9]+}} = insertvalue { float* } undef, float* %{{[0-9]+}}, 0
@ -377,10 +377,10 @@ func @memref_alloc() {
}
// CHECK-LABEL: declare i64 @get_index()
func @get_index() -> !llvm.i64
llvm.func @get_index() -> !llvm.i64
// CHECK-LABEL: define void @store_load_static()
func @store_load_static() {
llvm.func @store_load_static() {
^bb0:
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 40)
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
@ -448,7 +448,7 @@ func @store_load_static() {
}
// CHECK-LABEL: define void @store_load_dynamic(i64 {{%.*}})
func @store_load_dynamic(%arg0: !llvm.i64) {
llvm.func @store_load_dynamic(%arg0: !llvm.i64) {
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 %{{[0-9]+}})
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
@ -518,7 +518,7 @@ func @store_load_dynamic(%arg0: !llvm.i64) {
}
// CHECK-LABEL: define void @store_load_mixed(i64 {{%.*}})
func @store_load_mixed(%arg0: !llvm.i64) {
llvm.func @store_load_mixed(%arg0: !llvm.i64) {
%0 = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK-NEXT: %{{[0-9]+}} = mul i64 2, %{{[0-9]+}}
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
@ -603,7 +603,7 @@ func @store_load_mixed(%arg0: !llvm.i64) {
}
// CHECK-LABEL: define { float*, i64 } @memref_args_rets({ float* } {{%.*}}, { float*, i64 } {{%.*}}, { float*, i64 } {{%.*}}) {
func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
llvm.func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
%0 = llvm.mlir.constant(7 : index) : !llvm.i64
// CHECK-NEXT: %{{[0-9]+}} = call i64 @get_index()
%1 = llvm.call @get_index() : () -> !llvm.i64
@ -657,7 +657,7 @@ func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }
// CHECK-LABEL: define i64 @memref_dim({ float*, i64, i64 } {{%.*}})
func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
llvm.func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
// Expecting this to create an LLVM constant.
%0 = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %2 = extractvalue { float*, i64, i64 } %0, 1
@ -678,12 +678,12 @@ func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
llvm.return %6 : !llvm.i64
}
func @get_i64() -> !llvm.i64
func @get_f32() -> !llvm.float
func @get_memref() -> !llvm<"{ float*, i64, i64 }">
llvm.func @get_i64() -> !llvm.i64
llvm.func @get_f32() -> !llvm.float
llvm.func @get_memref() -> !llvm<"{ float*, i64, i64 }">
// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
llvm.func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
%0 = llvm.call @get_i64() : () -> !llvm.i64
%1 = llvm.call @get_f32() : () -> !llvm.float
%2 = llvm.call @get_memref() : () -> !llvm<"{ float*, i64, i64 }">
@ -700,7 +700,7 @@ func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
// CHECK-LABEL: define void @multireturn_caller() {
func @multireturn_caller() {
llvm.func @multireturn_caller() {
// CHECK-NEXT: %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
// CHECK-NEXT: [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
// CHECK-NEXT: [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
@ -734,7 +734,7 @@ func @multireturn_caller() {
}
// CHECK-LABEL: define <4 x float> @vector_ops(<4 x float> {{%.*}}, <4 x i1> {{%.*}}, <4 x i64> {{%.*}}) {
func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
llvm.func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
%0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
// CHECK-NEXT: %4 = fadd <4 x float> %0, <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
%1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
@ -763,7 +763,7 @@ func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !
}
// CHECK-LABEL: @ops
func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
// CHECK-NEXT: fsub float %0, %1
%0 = llvm.fsub %arg0, %arg1 : !llvm.float
// CHECK-NEXT: %6 = sub i32 %2, %3
@ -811,7 +811,7 @@ func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm
//
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) {
func @indirect_const_call(%arg0: !llvm.i64) {
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
// CHECK-NEXT: call void @body(i64 %0)
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
llvm.call %0(%arg0) : (!llvm.i64) -> ()
@ -820,7 +820,7 @@ func @indirect_const_call(%arg0: !llvm.i64) {
}
// CHECK-LABEL: define i32 @indirect_call(i32 (float)* {{%.*}}, float {{%.*}}) {
func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
llvm.func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
// CHECK-NEXT: %3 = call i32 %0(float %1)
%0 = llvm.call %arg0(%arg1) : (!llvm.float) -> !llvm.i32
// CHECK-NEXT: ret i32 %3
@ -833,7 +833,7 @@ func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i
//
// CHECK-LABEL: define void @cond_br_arguments(i1 {{%.*}}, i1 {{%.*}}) {
func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
llvm.func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
// CHECK-NEXT: br i1 %0, label %3, label %5
llvm.cond_br %arg0, ^bb1(%arg0 : !llvm.i1), ^bb2
@ -850,15 +850,14 @@ func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
}
// CHECK-LABEL: define void @llvm_noalias(float* noalias {{%*.}}) {
func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
llvm.func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
llvm.return
}
// CHECK-LABEL: @llvm_varargs(...)
func @llvm_varargs()
attributes {std.varargs = true}
llvm.func @llvm_varargs(...)
func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
llvm.func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
// CHECK: %2 = inttoptr i32 %0 to i32*
// CHECK-NEXT: %3 = ptrtoint i32* %2 to i32
%1 = llvm.inttoptr %arg0 : !llvm.i32 to !llvm<"i32*">
@ -866,19 +865,19 @@ func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
llvm.return %2 : !llvm.i32
}
func @stringconstant() -> !llvm<"i8*"> {
llvm.func @stringconstant() -> !llvm<"i8*"> {
%1 = llvm.mlir.constant("Hello world!") : !llvm<"i8*">
// CHECK: ret [12 x i8] c"Hello world!"
llvm.return %1 : !llvm<"i8*">
}
func @noreach() {
llvm.func @noreach() {
// CHECK: unreachable
llvm.unreachable
}
// CHECK-LABEL: define void @fcmp
func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
llvm.func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
// CHECK: fcmp oeq float %0, %1
// CHECK-NEXT: fcmp ogt float %0, %1
// CHECK-NEXT: fcmp oge float %0, %1
@ -911,7 +910,7 @@ func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
}
// CHECK-LABEL: @vect
func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
llvm.func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
// CHECK-NEXT: extractelement <4 x float> {{.*}}, i32 {{.*}}
// CHECK-NEXT: insertelement <4 x float> {{.*}}, float %2, i32 {{.*}}
// CHECK-NEXT: shufflevector <4 x float> {{.*}}, <4 x float> {{.*}}, <5 x i32> <i32 0, i32 0, i32 0, i32 0, i32 7>
@ -922,7 +921,7 @@ func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
}
// CHECK-LABEL: @alloca
func @alloca(%size : !llvm.i64) {
llvm.func @alloca(%size : !llvm.i64) {
// CHECK: alloca
// CHECK-NOT: align
llvm.alloca %size x !llvm.i32 {alignment = 0} : (!llvm.i64) -> (!llvm<"i32*">)
@ -932,13 +931,14 @@ func @alloca(%size : !llvm.i64) {
}
// CHECK-LABEL: @constants
func @constants() -> !llvm<"<4 x float>"> {
llvm.func @constants() -> !llvm<"<4 x float>"> {
// CHECK: ret <4 x float> <float 4.2{{0*}}e+01, float 0.{{0*}}e+00, float 0.{{0*}}e+00, float 0.{{0*}}e+00>
%0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : !llvm<"<4 x float>">
llvm.return %0 : !llvm<"<4 x float>">
}
func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
// CHECK-LABEL: @fp_casts
llvm.func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
// CHECK: fptrunc double {{.*}} to float
%a = llvm.fptrunc %fp2 : !llvm<"double"> to !llvm<"float">
// CHECK: fpext float {{.*}} to double
@ -948,7 +948,8 @@ func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
llvm.return %c : !llvm.i16
}
func @integer_extension_and_truncation(%a : !llvm.i32) {
// CHECK-LABEL: @integer_extension_and_truncation
llvm.func @integer_extension_and_truncation(%a : !llvm.i32) {
// CHECK: sext i32 {{.*}} to i64
// CHECK: zext i32 {{.*}} to i64
// CHECK: trunc i32 {{.*}} to i16


@ -1,7 +1,7 @@
// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
func @nvvm_special_regs() -> !llvm.i32 {
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
llvm.func @nvvm_special_regs() -> !llvm.i32 {
// CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
%2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
@ -32,13 +32,13 @@ func @nvvm_special_regs() -> !llvm.i32 {
llvm.return %1 : !llvm.i32
}
func @llvm.nvvm.barrier0() {
llvm.func @llvm.nvvm.barrier0() {
// CHECK: call void @llvm.nvvm.barrier0()
nvvm.barrier0
llvm.return
}
func @nvvm_shfl(
llvm.func @nvvm_shfl(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
@ -48,7 +48,7 @@ func @nvvm_shfl(
llvm.return %6 : !llvm.i32
}
func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
llvm.return %3 : !llvm.i32
@ -56,7 +56,7 @@ func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
func @kernel_func() attributes {gpu.kernel} {
llvm.func @kernel_func() attributes {gpu.kernel} {
llvm.return
}


@ -1,6 +1,6 @@
// RUN: mlir-translate -mlir-to-rocdlir %s | FileCheck %s
func @rocdl_special_regs() -> !llvm.i32 {
llvm.func @rocdl_special_regs() -> !llvm.i32 {
// CHECK-LABEL: rocdl_special_regs
// CHECK: call i32 @llvm.amdgcn.workitem.id.x()
%1 = rocdl.workitem.id.x : !llvm.i32
@ -29,7 +29,7 @@ func @rocdl_special_regs() -> !llvm.i32 {
llvm.return %1 : !llvm.i32
}
func @kernel_func() attributes {gpu.kernel} {
llvm.func @kernel_func() attributes {gpu.kernel} {
// CHECK-LABEL: amdgpu_kernel void @kernel_func
llvm.return
}


@ -12,12 +12,12 @@
// RUN: rm %T/test.o
// Declarations of C library functions.
func @fabsf(!llvm.float) -> !llvm.float
func @malloc(!llvm.i64) -> !llvm<"i8*">
func @free(!llvm<"i8*">)
llvm.func @fabsf(!llvm.float) -> !llvm.float
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @free(!llvm<"i8*">)
// Check that a simple function with a nested call works.
func @main() -> !llvm.float {
llvm.func @main() -> !llvm.float {
%0 = llvm.mlir.constant(-4.200000e+02 : f32) : !llvm.float
%1 = llvm.call @fabsf(%0) : (!llvm.float) -> !llvm.float
llvm.return %1 : !llvm.float
@ -25,13 +25,13 @@ func @main() -> !llvm.float {
// CHECK: 4.200000e+02
// Helper typed functions wrapping calls to "malloc" and "free".
func @allocation() -> !llvm<"float*"> {
llvm.func @allocation() -> !llvm<"float*"> {
%0 = llvm.mlir.constant(4 : index) : !llvm.i64
%1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*">
%2 = llvm.bitcast %1 : !llvm<"i8*"> to !llvm<"float*">
llvm.return %2 : !llvm<"float*">
}
func @deallocation(%arg0: !llvm<"float*">) {
llvm.func @deallocation(%arg0: !llvm<"float*">) {
%0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*">
llvm.call @free(%0) : (!llvm<"i8*">) -> ()
llvm.return
@ -39,7 +39,7 @@ func @deallocation(%arg0: !llvm<"float*">) {
// Check that allocation and deallocation work, and that a custom entry point
// works.
func @foo() -> !llvm.float {
llvm.func @foo() -> !llvm.float {
%0 = llvm.call @allocation() : () -> !llvm<"float*">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.mlir.constant(1.234000e+03 : f32) : !llvm.float


@ -1,6 +1,6 @@
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
func @print_0d() {
%f = constant 2.00000e+00 : f32