forked from OSchip/llvm-project
Use llvm.func to define functions with wrapped LLVM IR function type
This function-like operation allows one to define functions that have wrapped LLVM IR function type, in particular variadic functions. The operation was added in parallel to the existing lowering flow, this commit only switches the flow to use it. Using a custom function type makes the LLVM IR dialect type system more consistent and avoids complex conversion rules for functions that previously had to use the built-in function type instead of a wrapped LLVM IR dialect type and perform conversions during the analysis. PiperOrigin-RevId: 273910855
This commit is contained in:
parent
309b4556d0
commit
5e7959a353
|
@ -152,8 +152,6 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
|
|||
ConversionTarget target(*module.getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
if (failed(applyFullConversion(module, target, patterns, &converter)))
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -138,14 +138,14 @@ public:
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Get or create the declaration of the printf function in the module.
|
||||
FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
|
||||
LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
|
||||
|
||||
auto print = cast<toy::PrintOp>(op);
|
||||
auto loc = print.getLoc();
|
||||
// We will operate on a MemRef abstraction, we use a type.cast to get one
|
||||
// if our operand is still a Toy array.
|
||||
Value *operand = memRefTypeCast(rewriter, operands[0]);
|
||||
Type retTy = printfFunc.getType().getResult(0);
|
||||
Type retTy = printfFunc.getType().getFunctionResultType();
|
||||
|
||||
// Create our loop nest now
|
||||
using namespace edsc;
|
||||
|
@ -218,24 +218,23 @@ private:
|
|||
|
||||
/// Return the prototype declaration for printf in the module, create it if
|
||||
/// necessary.
|
||||
FuncOp getPrintf(ModuleOp module) const {
|
||||
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
|
||||
LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
|
||||
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
|
||||
if (printfFunc)
|
||||
return printfFunc;
|
||||
|
||||
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
|
||||
Builder builder(module);
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
auto *dialect =
|
||||
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
|
||||
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
|
||||
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
|
||||
printfFunc = FuncOp::create(builder.getUnknownLoc(), "printf", printfTy);
|
||||
// It should be variadic, but we don't support it fully just yet.
|
||||
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
|
||||
module.push_back(printfFunc);
|
||||
return printfFunc;
|
||||
auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
|
||||
/*isVarArg=*/true);
|
||||
return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
|
||||
printfTy,
|
||||
ArrayRef<NamedAttribute>());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -369,10 +368,10 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
|||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
|
||||
LLVM::LLVMDialect, StandardOpsDialect>();
|
||||
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType());
|
||||
});
|
||||
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
|
||||
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
|
||||
&typeConverter))) {
|
||||
emitError(UnknownLoc::get(getModule().getContext()),
|
||||
|
|
|
@ -137,49 +137,27 @@ Examples:
|
|||
|
||||
### Function Signature Conversion
|
||||
|
||||
MLIR function type is built into the representation, even the functions in
|
||||
dialects including a first-class function type must have the built-in MLIR
|
||||
function type. During the conversion to LLVM IR, function signatures are
|
||||
converted as follows:
|
||||
|
||||
- the outer type remains the built-in MLIR function;
|
||||
- function arguments are converted individually following these rules;
|
||||
- function results:
|
||||
- zero-result functions remain zero-result;
|
||||
- single-result functions have their result type converted according to
|
||||
these rules;
|
||||
- multi-result functions have a single result type of the wrapped LLVM IR
|
||||
structure type with elements corresponding to the converted original
|
||||
results.
|
||||
|
||||
Rationale: function definitions remain analyzable within MLIR without having to
|
||||
abstract away the function type. In order to remain consistent with the regular
|
||||
MLIR functions, we do not introduce a `void` result type since we cannot create
|
||||
a value of `void` type that MLIR passes might expect to be returned from a
|
||||
function.
|
||||
LLVM IR functions are defined by a custom operation. The function itself has a
|
||||
wrapped LLVM IR function type converted as described above. The function
|
||||
definition operation uses MLIR syntax.
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// zero-ary function type with no results.
|
||||
func @foo() -> ()
|
||||
// remains as is
|
||||
func @foo() -> ()
|
||||
// gets LLVM type void().
|
||||
llvm.func @foo() -> ()
|
||||
|
||||
// unary function with one result
|
||||
// function with one result
|
||||
func @bar(i32) -> (i64)
|
||||
// has its argument and result type converted
|
||||
func @bar(!llvm.type<"i32">) -> !llvm.type<"i64">
|
||||
// gets converted to LLVM type i64(i32).
|
||||
func @bar(!llvm.i32) -> !llvm.i64
|
||||
|
||||
// binary function with one result
|
||||
func @baz(i32, f32) -> (i64)
|
||||
// has its arguments handled separately
|
||||
func @baz(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"i64">
|
||||
|
||||
// binary function with two results
|
||||
// function with two results
|
||||
func @qux(i32, f32) -> (i64, f64)
|
||||
// has its result aggregated into a structure type
|
||||
func @qux(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"{i64, double}">
|
||||
func @qux(!llvm.i32, !llvm.float) -> !llvm.type<"{i64, double}">
|
||||
|
||||
// function-typed arguments or results in higher-order functions
|
||||
func @quux(() -> ()) -> (() -> ())
|
||||
|
|
|
@ -50,6 +50,30 @@ specific LLVM IR type.
|
|||
All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
|
||||
of an operation is that used in LLVM IR prefixed with "`llvm.`".
|
||||
|
||||
### LLVM functions
|
||||
|
||||
MLIR functions are defined by an operation that is not built into the IR itself.
|
||||
The LLVM IR dialect provides an `llvm.func` operation to define functions
|
||||
compatible with LLVM IR. These functions have wrapped LLVM IR function type but
|
||||
use MLIR syntax to express it. They are required to have exactly one result
|
||||
type. LLVM function operation is intended to capture additional properties of
|
||||
LLVM functions, such as linkage and calling convention, that may be modeled
|
||||
differently by the built-in MLIR function.
|
||||
|
||||
```mlir {.mlir}
|
||||
// The type of @bar is !llvm<"i64 (i64)">
|
||||
llvm.func @bar(%arg0: !llvm.i64) -> !llvm.i64 {
|
||||
llvm.return %arg0 : !llvm.i64
|
||||
}
|
||||
|
||||
// Type type of @foo is !llvm<"void (i64)">
|
||||
// !llvm.void type is omitted
|
||||
llvm.func @foo(%arg0: !llvm.i64) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### LLVM IR operations
|
||||
|
||||
The following operations are currently supported. The semantics of these
|
||||
|
|
|
@ -25,15 +25,12 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
class FuncOp;
|
||||
class Location;
|
||||
class ModuleOp;
|
||||
class OpBuilder;
|
||||
class Value;
|
||||
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
}
|
||||
} // namespace LLVM
|
||||
|
||||
template <typename T> class OpPassBase;
|
||||
|
||||
|
|
|
@ -50,6 +50,12 @@ public:
|
|||
/// non-standard or non-builtin types.
|
||||
Type convertType(Type t) override;
|
||||
|
||||
/// Convert a function type. The arguments and results are converted one by
|
||||
/// one and results are packed into a wrapped LLVM IR structure type. `result`
|
||||
/// is populated with argument mapping.
|
||||
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
|
||||
SignatureConversion &result);
|
||||
|
||||
/// Convert a non-empty list of types to be returned from a function into a
|
||||
/// supported LLVM IR type. In particular, if more than one values is
|
||||
/// returned, create an LLVM IR structure type with elements that correspond
|
||||
|
|
|
@ -55,7 +55,7 @@ public:
|
|||
|
||||
/// Returns whether the given function is a kernel function, i.e., has the
|
||||
/// 'gpu.kernel' attribute.
|
||||
static bool isKernel(FuncOp function);
|
||||
static bool isKernel(Operation *op);
|
||||
|
||||
LogicalResult verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) override;
|
||||
|
|
|
@ -64,6 +64,9 @@ public:
|
|||
LLVMDialect &getDialect();
|
||||
llvm::Type *getUnderlyingType() const;
|
||||
|
||||
/// Utilities to identify types.
|
||||
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
|
||||
|
||||
/// Array type utilities.
|
||||
LLVMType getArrayElementType();
|
||||
unsigned getArrayNumElements();
|
||||
|
|
|
@ -525,11 +525,15 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
|
|||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
|
||||
"LLVMType type, ArrayRef<NamedAttribute> attrs, "
|
||||
"LLVMType type, ArrayRef<NamedAttribute> attrs = {}, "
|
||||
"ArrayRef<NamedAttributeList> argAttrs = {}">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Add an entry block to an empty function, and set up the block arguments
|
||||
// to match the signature of the function.
|
||||
Block *addEntryBlock();
|
||||
|
||||
LLVMType getType() {
|
||||
return getAttrOfType<TypeAttr>(getTypeAttrName())
|
||||
.getValue().cast<LLVMType>();
|
||||
|
|
|
@ -34,13 +34,14 @@
|
|||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class FuncOp;
|
||||
class Location;
|
||||
class ModuleOp;
|
||||
class Operation;
|
||||
|
||||
namespace LLVM {
|
||||
|
||||
class LLVMFuncOp;
|
||||
|
||||
// Implementation class for module translation. Holds a reference to the module
|
||||
// being translated, and the mappings between the original and the translated
|
||||
// functions, basic blocks and values. It is practically easier to hold these
|
||||
|
@ -75,8 +76,8 @@ protected:
|
|||
private:
|
||||
LogicalResult convertFunctions();
|
||||
void convertGlobals();
|
||||
LogicalResult convertOneFunction(FuncOp func);
|
||||
void connectPHINodes(FuncOp func);
|
||||
LogicalResult convertOneFunction(LLVMFuncOp func);
|
||||
void connectPHINodes(LLVMFuncOp func);
|
||||
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
|
||||
|
||||
template <typename Range>
|
||||
|
|
|
@ -80,6 +80,7 @@ private:
|
|||
|
||||
void initializeCachedTypes() {
|
||||
const llvm::Module &module = llvmDialect->getLLVMModule();
|
||||
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
llvmPointerPointerType = llvmPointerType.getPointerTo();
|
||||
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
|
||||
|
@ -89,6 +90,8 @@ private:
|
|||
llvmDialect, module.getDataLayout().getPointerSizeInBits());
|
||||
}
|
||||
|
||||
LLVM::LLVMType getVoidType() { return llvmVoidType; }
|
||||
|
||||
LLVM::LLVMType getPointerType() { return llvmPointerType; }
|
||||
|
||||
LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
|
||||
|
@ -120,7 +123,7 @@ private:
|
|||
|
||||
void declareCudaFunctions(Location loc);
|
||||
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
|
||||
Value *generateKernelNameConstant(StringRef name, Location &loc,
|
||||
Value *generateKernelNameConstant(StringRef name, Location loc,
|
||||
OpBuilder &builder);
|
||||
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
|
||||
|
||||
|
@ -145,6 +148,7 @@ public:
|
|||
|
||||
private:
|
||||
LLVM::LLVMDialect *llvmDialect;
|
||||
LLVM::LLVMType llvmVoidType;
|
||||
LLVM::LLVMType llvmPointerType;
|
||||
LLVM::LLVMType llvmPointerPointerType;
|
||||
LLVM::LLVMType llvmInt8Type;
|
||||
|
@ -160,38 +164,41 @@ private:
|
|||
// uses void pointers. This is fine as they have the same linkage in C.
|
||||
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||
ModuleOp module = getModule();
|
||||
Builder builder(module);
|
||||
if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
|
||||
module.push_back(
|
||||
FuncOp::create(loc, cuModuleLoadName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* CUmodule *module */
|
||||
getPointerType() /* void *cubin */
|
||||
},
|
||||
getCUResultType())));
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
if (!module.lookupSymbol(cuModuleLoadName)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, cuModuleLoadName,
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getCUResultType(),
|
||||
{
|
||||
getPointerPointerType(), /* CUmodule *module */
|
||||
getPointerType() /* void *cubin */
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
|
||||
if (!module.lookupSymbol(cuModuleGetFunctionName)) {
|
||||
// The helper uses void* instead of CUDA's opaque CUmodule and
|
||||
// CUfunction.
|
||||
module.push_back(
|
||||
FuncOp::create(loc, cuModuleGetFunctionName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* void **function */
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
getCUResultType())));
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, cuModuleGetFunctionName,
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getCUResultType(),
|
||||
{
|
||||
getPointerPointerType(), /* void **function */
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
|
||||
if (!module.lookupSymbol(cuLaunchKernelName)) {
|
||||
// Other than the CUDA api, the wrappers use uintptr_t to match the
|
||||
// LLVM type if MLIR's index type, which the GPU dialect uses.
|
||||
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
|
||||
// CUstream.
|
||||
module.push_back(FuncOp::create(
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, cuLaunchKernelName,
|
||||
builder.getFunctionType(
|
||||
LLVM::LLVMType::getFunctionTy(
|
||||
getCUResultType(),
|
||||
{
|
||||
getPointerType(), /* void* f */
|
||||
getIntPtrType(), /* intptr_t gridXDim */
|
||||
|
@ -205,32 +212,31 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
|||
getPointerPointerType(), /* void **kernelParams */
|
||||
getPointerPointerType() /* void **extra */
|
||||
},
|
||||
getCUResultType())));
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
|
||||
if (!module.lookupSymbol(cuGetStreamHelperName)) {
|
||||
// Helper function to get the current CUDA stream. Uses void* instead of
|
||||
// CUDAs opaque CUstream.
|
||||
module.push_back(FuncOp::create(
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, cuGetStreamHelperName,
|
||||
builder.getFunctionType({}, getPointerType() /* void *stream */)));
|
||||
LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
|
||||
module.push_back(
|
||||
FuncOp::create(loc, cuStreamSynchronizeName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType() /* CUstream stream */
|
||||
},
|
||||
getCUResultType())));
|
||||
if (!module.lookupSymbol(cuStreamSynchronizeName)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, cuStreamSynchronizeName,
|
||||
LLVM::LLVMType::getFunctionTy(getCUResultType(),
|
||||
getPointerType() /* CUstream stream */,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
if (!module.lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr)) {
|
||||
module.push_back(FuncOp::create(loc, kMcuMemHostRegisterPtr,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType(), /* void *ptr */
|
||||
getInt32Type() /* int32 flags*/
|
||||
},
|
||||
{})));
|
||||
if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kMcuMemHostRegisterPtr,
|
||||
LLVM::LLVMType::getFunctionTy(getVoidType(),
|
||||
{
|
||||
getPointerType(), /* void *ptr */
|
||||
getInt32Type() /* int32 flags*/
|
||||
},
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -271,7 +277,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
|||
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
|
||||
if (llvmType.isStructTy()) {
|
||||
auto registerFunc =
|
||||
getModule().lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr);
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr);
|
||||
auto zero = builder.create<LLVM::ConstantOp>(
|
||||
loc, getInt32Type(), builder.getI32IntegerAttr(0));
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
|
||||
|
@ -304,7 +310,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
|||
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
|
||||
// }
|
||||
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
|
||||
StringRef name, Location &loc, OpBuilder &builder) {
|
||||
StringRef name, Location loc, OpBuilder &builder) {
|
||||
// Make sure the trailing zero is included in the constant.
|
||||
std::vector<char> kernelName(name.begin(), name.end());
|
||||
kernelName.push_back('\0');
|
||||
|
@ -355,6 +361,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
<< "missing " << kCubinAnnotation << " attribute";
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
assert(kernelModule.getName() && "expected a named module");
|
||||
SmallString<128> nameBuffer(*kernelModule.getName());
|
||||
nameBuffer.append(kCubinStorageSuffix);
|
||||
|
@ -364,7 +371,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// Emit the load module call to load the module data. Error checking is done
|
||||
// in the called helper function.
|
||||
auto cuModule = allocatePointer(builder, loc);
|
||||
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
|
||||
auto cuModuleLoad =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuModuleLoad),
|
||||
ArrayRef<Value *>{cuModule, data});
|
||||
|
@ -374,20 +382,21 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
|
||||
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
|
||||
auto cuFunction = allocatePointer(builder, loc);
|
||||
FuncOp cuModuleGetFunction =
|
||||
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
|
||||
auto cuModuleGetFunction =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
FuncOp cuGetStreamHelper =
|
||||
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
|
||||
auto cuGetStreamHelper =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
|
||||
auto cuStream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
||||
// Invoke the function with required arguments.
|
||||
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
|
||||
auto cuLaunchKernel =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
|
||||
auto cuFunctionRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
|
||||
auto paramsArray = setupParamsArray(launchOp, builder);
|
||||
|
@ -404,7 +413,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
paramsArray, /* kernel params */
|
||||
nullpointer /* extra */});
|
||||
// Sync on the stream to make it synchronous.
|
||||
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
|
||||
auto cuStreamSync =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuStreamSync),
|
||||
ArrayRef<Value *>(cuStream.getResult(0)));
|
||||
|
|
|
@ -381,8 +381,6 @@ public:
|
|||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalDialect<NVVM::NVVMDialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
if (failed(applyPartialConversion(m, target, patterns, &converter)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -95,19 +95,31 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
|
|||
}
|
||||
}
|
||||
|
||||
// Except for signatures, MLIR function types are converted into LLVM
|
||||
// pointer-to-function types.
|
||||
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
||||
SignatureConversion conversion(type.getNumInputs());
|
||||
LLVM::LLVMType converted =
|
||||
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
|
||||
return converted.getPointerTo();
|
||||
}
|
||||
|
||||
// Function types are converted to LLVM Function types by recursively converting
|
||||
// argument and result types. If MLIR Function has zero results, the LLVM
|
||||
// Function has one VoidType result. If MLIR Function has more than one result,
|
||||
// they are into an LLVM StructType in their order of appearance.
|
||||
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
||||
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
|
||||
FunctionType type, bool isVariadic,
|
||||
LLVMTypeConverter::SignatureConversion &result) {
|
||||
// Convert argument types one by one and check for errors.
|
||||
SmallVector<LLVM::LLVMType, 8> argTypes;
|
||||
for (auto t : type.getInputs()) {
|
||||
auto converted = convertType(t);
|
||||
if (!converted)
|
||||
for (auto &en : llvm::enumerate(type.getInputs()))
|
||||
if (failed(convertSignatureArg(en.index(), en.value(), result)))
|
||||
return {};
|
||||
argTypes.push_back(unwrap(converted));
|
||||
}
|
||||
|
||||
SmallVector<LLVM::LLVMType, 8> argTypes;
|
||||
argTypes.reserve(llvm::size(result.getConvertedTypes()));
|
||||
for (Type type : result.getConvertedTypes())
|
||||
argTypes.push_back(unwrap(type));
|
||||
|
||||
// If function does not return anything, create the void result type,
|
||||
// if it returns on element, convert it, otherwise pack the result types into
|
||||
|
@ -118,8 +130,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
|||
: unwrap(packFunctionResults(type.getResults()));
|
||||
if (!resultType)
|
||||
return {};
|
||||
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
|
||||
.getPointerTo();
|
||||
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
|
||||
}
|
||||
|
||||
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
||||
|
@ -249,6 +260,10 @@ public:
|
|||
&dialect, getModule().getDataLayout().getPointerSizeInBits());
|
||||
}
|
||||
|
||||
LLVM::LLVMType getVoidType() const {
|
||||
return LLVM::LLVMType::getVoidTy(&dialect);
|
||||
}
|
||||
|
||||
// Get the MLIR type wrapping the LLVM i8* type.
|
||||
LLVM::LLVMType getVoidPtrType() const {
|
||||
return LLVM::LLVMType::getInt8PtrTy(&dialect);
|
||||
|
@ -289,7 +304,16 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
FunctionType type = funcOp.getType();
|
||||
SmallVector<Type, 4> argTypes;
|
||||
// Pack the result types into a struct.
|
||||
Type packedResult;
|
||||
if (type.getNumResults() != 0)
|
||||
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
|
||||
return matchFailure();
|
||||
LLVM::LLVMType resultType = packedResult
|
||||
? packedResult.cast<LLVM::LLVMType>()
|
||||
: LLVM::LLVMType::getVoidTy(&dialect);
|
||||
|
||||
SmallVector<LLVM::LLVMType, 4> argTypes;
|
||||
argTypes.reserve(type.getNumInputs());
|
||||
SmallVector<unsigned, 4> promotedArgIndices;
|
||||
promotedArgIndices.reserve(type.getNumInputs());
|
||||
|
@ -297,14 +321,15 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
|||
// Convert the original function arguments. Struct arguments are promoted to
|
||||
// pointer to struct arguments to allow calling external functions with
|
||||
// various ABIs (e.g. compiled from C/C++ on platform X).
|
||||
TypeConverter::SignatureConversion result(type.getNumInputs());
|
||||
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
for (auto en : llvm::enumerate(type.getInputs())) {
|
||||
auto t = en.value();
|
||||
auto converted = lowering.convertType(t);
|
||||
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
|
||||
if (!converted)
|
||||
return matchFailure();
|
||||
if (t.isa<MemRefType>()) {
|
||||
converted = converted.cast<LLVM::LLVMType>().getPointerTo();
|
||||
converted = converted.getPointerTo();
|
||||
promotedArgIndices.push_back(en.index());
|
||||
}
|
||||
argTypes.push_back(converted);
|
||||
|
@ -312,21 +337,24 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
|||
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
|
||||
result.addInputs(idx, argTypes[idx]);
|
||||
|
||||
// Pack the result types into a struct.
|
||||
Type packedResult;
|
||||
if (type.getNumResults() != 0) {
|
||||
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
|
||||
return matchFailure();
|
||||
auto llvmType = LLVM::LLVMType::getFunctionTy(
|
||||
resultType, argTypes, varargsAttr && varargsAttr.getValue());
|
||||
|
||||
// Only retain those attributes that are not constructed by build.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
for (const auto &attr : funcOp.getAttrs()) {
|
||||
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
|
||||
attr.first.is(impl::getTypeAttrName()) ||
|
||||
attr.first.is("std.varargs"))
|
||||
continue;
|
||||
attributes.push_back(attr);
|
||||
}
|
||||
|
||||
// Create a new function with an updated signature.
|
||||
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
||||
// Create an LLVM funcion.
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
op->getLoc(), funcOp.getName(), llvmType, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
newFuncOp.setType(FunctionType::get(
|
||||
result.getConvertedTypes(),
|
||||
packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
|
||||
funcOp.getContext()));
|
||||
|
||||
// Tell the rewriter to convert the region signature.
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
|
||||
|
@ -627,13 +655,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
||||
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType =
|
||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
||||
mallocFunc =
|
||||
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
module.push_back(mallocFunc);
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
|
@ -792,12 +820,14 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
FuncOp freeFunc =
|
||||
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
|
||||
auto freeFunc =
|
||||
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
op->getParentOfType<ModuleOp>().push_back(freeFunc);
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "free",
|
||||
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||
|
@ -1373,9 +1403,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
|||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter->isSignatureLegal(op.getType());
|
||||
});
|
||||
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -6,5 +6,5 @@ add_llvm_library(MLIRGPU
|
|||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
|
||||
)
|
||||
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR LLVMSupport)
|
||||
target_link_libraries(MLIRGPU MLIRIR MLIRStandardOps LLVMSupport)
|
||||
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR MLIRLLVMIR LLVMSupport)
|
||||
target_link_libraries(MLIRGPU MLIRIR MLIRLLVMIR MLIRStandardOps LLVMSupport)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@ -37,9 +38,8 @@ using namespace mlir::gpu;
|
|||
|
||||
StringRef GPUDialect::getDialectName() { return "gpu"; }
|
||||
|
||||
bool GPUDialect::isKernel(FuncOp function) {
|
||||
UnitAttr isKernelAttr =
|
||||
function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
bool GPUDialect::isKernel(Operation *op) {
|
||||
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
return static_cast<bool>(isKernelAttr);
|
||||
}
|
||||
|
||||
|
@ -92,18 +92,25 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
|
|||
|
||||
// Check that `launch_func` refers to a well-formed kernel function.
|
||||
StringRef kernelName = launchOp.kernel();
|
||||
auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(kernelName);
|
||||
if (!kernelFunction)
|
||||
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
|
||||
auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
|
||||
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
|
||||
if (!kernelStdFunction && !kernelLLVMFunction)
|
||||
return launchOp.emitOpError("kernel function '")
|
||||
<< kernelName << "' is undefined";
|
||||
if (!kernelFunction.getAttrOfType<mlir::UnitAttr>(
|
||||
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
|
||||
GPUDialect::getKernelFuncAttrName()))
|
||||
return launchOp.emitOpError("kernel function is missing the '")
|
||||
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
|
||||
if (launchOp.getNumKernelOperands() != kernelFunction.getNumArguments())
|
||||
return launchOp.emitOpError("got ") << launchOp.getNumKernelOperands()
|
||||
<< " kernel operands but expected "
|
||||
<< kernelFunction.getNumArguments();
|
||||
|
||||
unsigned actualNumArguments = launchOp.getNumKernelOperands();
|
||||
unsigned expectedNumArguments = kernelLLVMFunction
|
||||
? kernelLLVMFunction.getNumArguments()
|
||||
: kernelStdFunction.getNumArguments();
|
||||
if (expectedNumArguments != actualNumArguments)
|
||||
return launchOp.emitOpError("got ")
|
||||
<< actualNumArguments << " kernel operands but expected "
|
||||
<< expectedNumArguments;
|
||||
|
||||
// Due to the ordering of the current impl of lowering and LLVMLowering,
|
||||
// type checks need to be temporarily disabled.
|
||||
|
|
|
@ -178,15 +178,17 @@ public:
|
|||
auto indexType = IndexType::get(op->getContext());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
|
||||
.cast<LLVM::LLVMType>();
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
||||
auto mallocFunc = module.lookupSymbol<LLVMFuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
||||
mallocFunc =
|
||||
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
module.push_back(mallocFunc);
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
mallocFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Get MLIR types for injecting element pointer.
|
||||
|
@ -257,15 +259,18 @@ public:
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
|
||||
auto freeFunc = module.lookupSymbol<LLVMFuncOp>("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
module.push_back(freeFunc);
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
freeFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "free",
|
||||
LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Emit MLIR for buffer_dealloc.
|
||||
|
|
|
@ -177,7 +177,7 @@ compileAndExecute(ModuleOp module, StringRef entryPoint,
|
|||
static Error compileAndExecuteVoidFunction(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.getBlocks().empty())
|
||||
return make_string_error("entry point not found");
|
||||
void *empty = nullptr;
|
||||
|
@ -187,22 +187,14 @@ static Error compileAndExecuteVoidFunction(
|
|||
static Error compileAndExecuteSingleFloatReturnFunction(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.isExternal()) {
|
||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.isExternal())
|
||||
return make_string_error("entry point not found");
|
||||
}
|
||||
|
||||
if (!mainFunction.getType().getInputs().empty())
|
||||
if (mainFunction.getType().getFunctionNumParams() != 0)
|
||||
return make_string_error("function inputs not supported");
|
||||
|
||||
if (mainFunction.getType().getResults().size() != 1)
|
||||
return make_string_error("only single f32 function result supported");
|
||||
|
||||
auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
|
||||
if (!t)
|
||||
return make_string_error("only single llvm.f32 function result supported");
|
||||
auto *llvmTy = t.getUnderlyingType();
|
||||
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
|
||||
if (!mainFunction.getType().getFunctionResultType().isFloatTy())
|
||||
return make_string_error("only single llvm.f32 function result supported");
|
||||
|
||||
float res;
|
||||
|
|
|
@ -25,6 +25,7 @@ add_llvm_library(MLIRTargetNVVMIR
|
|||
target_link_libraries(MLIRTargetNVVMIR
|
||||
MLIRGPU
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRNVVMIR
|
||||
MLIRTargetLLVMIRModuleTranslation
|
||||
)
|
||||
|
@ -39,6 +40,7 @@ add_llvm_library(MLIRTargetROCDLIR
|
|||
target_link_libraries(MLIRTargetROCDLIR
|
||||
MLIRGPU
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRROCDLIR
|
||||
MLIRTargetLLVMIRModuleTranslation
|
||||
)
|
||||
|
|
|
@ -23,8 +23,8 @@
|
|||
#include "mlir/Target/NVVMIR.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||
#include "mlir/Translation.h"
|
||||
|
@ -66,11 +66,13 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
|||
ModuleTranslation translation(m);
|
||||
auto llvmModule =
|
||||
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
||||
if (!llvmModule)
|
||||
return llvmModule;
|
||||
|
||||
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
||||
// function as a kernel.
|
||||
for (FuncOp func : m.getOps<FuncOp>()) {
|
||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||
if (!gpu::GPUDialect::isKernel(func))
|
||||
continue;
|
||||
|
||||
auto *llvmFunc = llvmModule->getFunction(func.getName());
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Target/ROCDLIR.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
@ -93,7 +94,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
|||
// foreach GPU kernel
|
||||
// 1. Insert AMDGPU_KERNEL calling convention.
|
||||
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
||||
for (FuncOp func : m.getOps<FuncOp>()) {
|
||||
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||
continue;
|
||||
|
||||
|
|
|
@ -39,38 +39,6 @@
|
|||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
// Convert an MLIR function type to LLVM IR. Arguments of the function must of
|
||||
// MLIR LLVM IR dialect types. Use `loc` as a location when reporting errors.
|
||||
// Return nullptr on errors.
|
||||
static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
|
||||
FunctionType type, Location loc,
|
||||
bool isVarArgs) {
|
||||
assert(type && "expected non-null type");
|
||||
if (type.getNumResults() > 1)
|
||||
return emitError(loc, "LLVM functions can only have 0 or 1 result"),
|
||||
nullptr;
|
||||
|
||||
SmallVector<llvm::Type *, 8> argTypes;
|
||||
argTypes.reserve(type.getNumInputs());
|
||||
for (auto t : type.getInputs()) {
|
||||
auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
|
||||
if (!wrappedLLVMType)
|
||||
return emitError(loc, "non-LLVM function argument type"), nullptr;
|
||||
argTypes.push_back(wrappedLLVMType.getUnderlyingType());
|
||||
}
|
||||
|
||||
if (type.getNumResults() == 0)
|
||||
return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
|
||||
isVarArgs);
|
||||
|
||||
auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
|
||||
if (!wrappedResultType)
|
||||
return emitError(loc, "non-LLVM function result"), nullptr;
|
||||
|
||||
return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
|
||||
argTypes, isVarArgs);
|
||||
}
|
||||
|
||||
// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
|
||||
// This currently supports integer, floating point, splat and dense element
|
||||
// attributes and combinations thereof. In case of error, report it to `loc`
|
||||
|
@ -362,7 +330,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
|
|||
: terminator.getSuccessorOperand(1, index);
|
||||
}
|
||||
|
||||
void ModuleTranslation::connectPHINodes(FuncOp func) {
|
||||
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
|
||||
// Skip the first block, it cannot be branched to and its arguments correspond
|
||||
// to the arguments of the LLVM function.
|
||||
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
|
||||
|
@ -393,7 +361,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
|
|||
}
|
||||
|
||||
// Sort function blocks topologically.
|
||||
static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
|
||||
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
|
||||
// For each blocks that has not been visited yet (i.e. that has no
|
||||
// predecessors), add it to the list and traverse its successors in DFS
|
||||
// preorder.
|
||||
|
@ -407,7 +375,7 @@ static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
|
|||
return blocks;
|
||||
}
|
||||
|
||||
LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
|
||||
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
|
||||
// Clear the block and value mappings, they are only relevant within one
|
||||
// function.
|
||||
blockMapping.clear();
|
||||
|
@ -460,24 +428,17 @@ LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
|
|||
LogicalResult ModuleTranslation::convertFunctions() {
|
||||
// Declare all functions first because there may be function calls that form a
|
||||
// call graph with cycles.
|
||||
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
|
||||
mlir::BoolAttr isVarArgsAttr =
|
||||
function.getAttrOfType<BoolAttr>("std.varargs");
|
||||
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
|
||||
llvm::FunctionType *functionType =
|
||||
convertFunctionType(llvmModule->getContext(), function.getType(),
|
||||
function.getLoc(), isVarArgs);
|
||||
if (!functionType)
|
||||
return failure();
|
||||
llvm::FunctionCallee llvmFuncCst =
|
||||
llvmModule->getOrInsertFunction(function.getName(), functionType);
|
||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
|
||||
function.getName(),
|
||||
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
|
||||
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
|
||||
functionMapping[function.getName()] =
|
||||
cast<llvm::Function>(llvmFuncCst.getCallee());
|
||||
}
|
||||
|
||||
// Convert functions.
|
||||
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
|
||||
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||
// Ignore external functions.
|
||||
if (function.isExternal())
|
||||
continue;
|
||||
|
|
|
@ -10,7 +10,7 @@ module attributes {gpu.container_module} {
|
|||
attributes { gpu.kernel }
|
||||
}
|
||||
|
||||
func @foo() {
|
||||
llvm.func @foo() {
|
||||
%0 = "op"() : () -> (!llvm.float)
|
||||
%1 = "op"() : () -> (!llvm<"float*">)
|
||||
%cst = constant 8 : index
|
||||
|
@ -29,7 +29,7 @@ module attributes {gpu.container_module} {
|
|||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernel_module }
|
||||
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
|
||||
|
||||
return
|
||||
llvm.return
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: attributes {gpu.kernel_module, nvvm.cubin = "CUBIN"}
|
||||
module @kernels attributes {gpu.kernel_module} {
|
||||
func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
|
||||
module @foo attributes {gpu.kernel_module} {
|
||||
llvm.func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
|
||||
// CHECK: attributes {gpu.kernel}
|
||||
attributes { gpu.kernel } {
|
||||
llvm.return
|
||||
}
|
||||
|
@ -10,15 +11,15 @@ module @kernels attributes {gpu.kernel_module} {
|
|||
|
||||
// -----
|
||||
|
||||
module attributes {gpu.kernel_module} {
|
||||
module @bar attributes {gpu.kernel_module} {
|
||||
// CHECK: func @kernel_a
|
||||
func @kernel_a()
|
||||
llvm.func @kernel_a()
|
||||
attributes { gpu.kernel } {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK: func @kernel_b
|
||||
func @kernel_b()
|
||||
llvm.func @kernel_b()
|
||||
attributes { gpu.kernel } {
|
||||
llvm.return
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @fmuladd_test
|
||||
func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
|
||||
llvm.func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
|
||||
// CHECK: call float @llvm.fmuladd.f32.f32.f32
|
||||
"llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
|
||||
// CHECK: call <8 x float> @llvm.fmuladd.v8f32.v8f32.v8f32
|
||||
|
@ -10,7 +10,7 @@ func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x fl
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @exp_test
|
||||
func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
||||
llvm.func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
||||
// CHECK: call float @llvm.exp.f32
|
||||
"llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
|
||||
// CHECK: call <8 x float> @llvm.exp.v8f32
|
||||
|
|
|
@ -29,7 +29,7 @@ llvm.mlir.global @int_global_undef() : !llvm.i64
|
|||
//
|
||||
|
||||
// CHECK: declare i8* @malloc(i64)
|
||||
func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK: declare void @free(i8*)
|
||||
|
||||
|
||||
|
@ -41,12 +41,12 @@ func @malloc(!llvm.i64) -> !llvm<"i8*">
|
|||
// CHECK-LABEL: define void @empty() {
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT: }
|
||||
func @empty() {
|
||||
llvm.func @empty() {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @global_refs
|
||||
func @global_refs() {
|
||||
llvm.func @global_refs() {
|
||||
// Check load from globals.
|
||||
// CHECK: load i32, i32* @i32_global
|
||||
%0 = llvm.mlir.addressof @i32_global : !llvm<"i32*">
|
||||
|
@ -63,11 +63,11 @@ func @global_refs() {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: declare void @body(i64)
|
||||
func @body(!llvm.i64)
|
||||
llvm.func @body(!llvm.i64)
|
||||
|
||||
|
||||
// CHECK-LABEL: define void @simple_loop() {
|
||||
func @simple_loop() {
|
||||
llvm.func @simple_loop() {
|
||||
// CHECK: br label %[[SIMPLE_bb1:[0-9]+]]
|
||||
llvm.br ^bb1
|
||||
|
||||
|
@ -107,7 +107,7 @@ func @simple_loop() {
|
|||
// CHECK-NEXT: call void @simple_loop()
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT: }
|
||||
func @simple_caller() {
|
||||
llvm.func @simple_caller() {
|
||||
llvm.call @simple_loop() : () -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
@ -124,20 +124,20 @@ func @simple_caller() {
|
|||
// CHECK-NEXT: call void @more_imperfectly_nested_loops()
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT: }
|
||||
func @ml_caller() {
|
||||
llvm.func @ml_caller() {
|
||||
llvm.call @simple_loop() : () -> ()
|
||||
llvm.call @more_imperfectly_nested_loops() : () -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: declare i64 @body_args(i64)
|
||||
func @body_args(!llvm.i64) -> !llvm.i64
|
||||
llvm.func @body_args(!llvm.i64) -> !llvm.i64
|
||||
// CHECK-LABEL: declare i32 @other(i64, i32)
|
||||
func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
|
||||
llvm.func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
|
||||
|
||||
// CHECK-LABEL: define i32 @func_args(i32 {{%.*}}, i32 {{%.*}}) {
|
||||
// CHECK-NEXT: br label %[[ARGS_bb1:[0-9]+]]
|
||||
func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
||||
llvm.func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
||||
%0 = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||
llvm.br ^bb1
|
||||
|
||||
|
@ -182,17 +182,17 @@ func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
|||
}
|
||||
|
||||
// CHECK: declare void @pre(i64)
|
||||
func @pre(!llvm.i64)
|
||||
llvm.func @pre(!llvm.i64)
|
||||
|
||||
// CHECK: declare void @body2(i64, i64)
|
||||
func @body2(!llvm.i64, !llvm.i64)
|
||||
llvm.func @body2(!llvm.i64, !llvm.i64)
|
||||
|
||||
// CHECK: declare void @post(i64)
|
||||
func @post(!llvm.i64)
|
||||
llvm.func @post(!llvm.i64)
|
||||
|
||||
// CHECK-LABEL: define void @imperfectly_nested_loops() {
|
||||
// CHECK-NEXT: br label %[[IMPER_bb1:[0-9]+]]
|
||||
func @imperfectly_nested_loops() {
|
||||
llvm.func @imperfectly_nested_loops() {
|
||||
llvm.br ^bb1
|
||||
|
||||
// CHECK: [[IMPER_bb1]]:
|
||||
|
@ -259,10 +259,10 @@ func @imperfectly_nested_loops() {
|
|||
}
|
||||
|
||||
// CHECK: declare void @mid(i64)
|
||||
func @mid(!llvm.i64)
|
||||
llvm.func @mid(!llvm.i64)
|
||||
|
||||
// CHECK: declare void @body3(i64, i64)
|
||||
func @body3(!llvm.i64, !llvm.i64)
|
||||
llvm.func @body3(!llvm.i64, !llvm.i64)
|
||||
|
||||
// A complete function transformation check.
|
||||
// CHECK-LABEL: define void @more_imperfectly_nested_loops() {
|
||||
|
@ -306,7 +306,7 @@ func @body3(!llvm.i64, !llvm.i64)
|
|||
// CHECK: 21: ; preds = %2
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT: }
|
||||
func @more_imperfectly_nested_loops() {
|
||||
llvm.func @more_imperfectly_nested_loops() {
|
||||
llvm.br ^bb1
|
||||
^bb1: // pred: ^bb0
|
||||
%0 = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
|
@ -359,7 +359,7 @@ func @more_imperfectly_nested_loops() {
|
|||
//
|
||||
|
||||
// CHECK-LABEL: define void @memref_alloc()
|
||||
func @memref_alloc() {
|
||||
llvm.func @memref_alloc() {
|
||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 400)
|
||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||
// CHECK-NEXT: %{{[0-9]+}} = insertvalue { float* } undef, float* %{{[0-9]+}}, 0
|
||||
|
@ -377,10 +377,10 @@ func @memref_alloc() {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: declare i64 @get_index()
|
||||
func @get_index() -> !llvm.i64
|
||||
llvm.func @get_index() -> !llvm.i64
|
||||
|
||||
// CHECK-LABEL: define void @store_load_static()
|
||||
func @store_load_static() {
|
||||
llvm.func @store_load_static() {
|
||||
^bb0:
|
||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 40)
|
||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||
|
@ -448,7 +448,7 @@ func @store_load_static() {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define void @store_load_dynamic(i64 {{%.*}})
|
||||
func @store_load_dynamic(%arg0: !llvm.i64) {
|
||||
llvm.func @store_load_dynamic(%arg0: !llvm.i64) {
|
||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 %{{[0-9]+}})
|
||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||
|
@ -518,7 +518,7 @@ func @store_load_dynamic(%arg0: !llvm.i64) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define void @store_load_mixed(i64 {{%.*}})
|
||||
func @store_load_mixed(%arg0: !llvm.i64) {
|
||||
llvm.func @store_load_mixed(%arg0: !llvm.i64) {
|
||||
%0 = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 2, %{{[0-9]+}}
|
||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
||||
|
@ -603,7 +603,7 @@ func @store_load_mixed(%arg0: !llvm.i64) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define { float*, i64 } @memref_args_rets({ float* } {{%.*}}, { float*, i64 } {{%.*}}, { float*, i64 } {{%.*}}) {
|
||||
func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
|
||||
llvm.func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
|
||||
%0 = llvm.mlir.constant(7 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %{{[0-9]+}} = call i64 @get_index()
|
||||
%1 = llvm.call @get_index() : () -> !llvm.i64
|
||||
|
@ -657,7 +657,7 @@ func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }
|
|||
|
||||
|
||||
// CHECK-LABEL: define i64 @memref_dim({ float*, i64, i64 } {{%.*}})
|
||||
func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
||||
llvm.func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
||||
// Expecting this to create an LLVM constant.
|
||||
%0 = llvm.mlir.constant(42 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %2 = extractvalue { float*, i64, i64 } %0, 1
|
||||
|
@ -678,12 +678,12 @@ func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
|||
llvm.return %6 : !llvm.i64
|
||||
}
|
||||
|
||||
func @get_i64() -> !llvm.i64
|
||||
func @get_f32() -> !llvm.float
|
||||
func @get_memref() -> !llvm<"{ float*, i64, i64 }">
|
||||
llvm.func @get_i64() -> !llvm.i64
|
||||
llvm.func @get_f32() -> !llvm.float
|
||||
llvm.func @get_memref() -> !llvm<"{ float*, i64, i64 }">
|
||||
|
||||
// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
|
||||
func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
||||
llvm.func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
||||
%0 = llvm.call @get_i64() : () -> !llvm.i64
|
||||
%1 = llvm.call @get_f32() : () -> !llvm.float
|
||||
%2 = llvm.call @get_memref() : () -> !llvm<"{ float*, i64, i64 }">
|
||||
|
@ -700,7 +700,7 @@ func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
|||
|
||||
|
||||
// CHECK-LABEL: define void @multireturn_caller() {
|
||||
func @multireturn_caller() {
|
||||
llvm.func @multireturn_caller() {
|
||||
// CHECK-NEXT: %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
|
||||
// CHECK-NEXT: [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
|
||||
// CHECK-NEXT: [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
|
||||
|
@ -734,7 +734,7 @@ func @multireturn_caller() {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define <4 x float> @vector_ops(<4 x float> {{%.*}}, <4 x i1> {{%.*}}, <4 x i64> {{%.*}}) {
|
||||
func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
|
||||
llvm.func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
|
||||
%0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
|
||||
// CHECK-NEXT: %4 = fadd <4 x float> %0, <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
|
||||
%1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
|
||||
|
@ -763,7 +763,7 @@ func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @ops
|
||||
func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
|
||||
llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
|
||||
// CHECK-NEXT: fsub float %0, %1
|
||||
%0 = llvm.fsub %arg0, %arg1 : !llvm.float
|
||||
// CHECK-NEXT: %6 = sub i32 %2, %3
|
||||
|
@ -811,7 +811,7 @@ func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm
|
|||
//
|
||||
|
||||
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) {
|
||||
func @indirect_const_call(%arg0: !llvm.i64) {
|
||||
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
|
||||
// CHECK-NEXT: call void @body(i64 %0)
|
||||
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
|
||||
llvm.call %0(%arg0) : (!llvm.i64) -> ()
|
||||
|
@ -820,7 +820,7 @@ func @indirect_const_call(%arg0: !llvm.i64) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define i32 @indirect_call(i32 (float)* {{%.*}}, float {{%.*}}) {
|
||||
func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
|
||||
llvm.func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
|
||||
// CHECK-NEXT: %3 = call i32 %0(float %1)
|
||||
%0 = llvm.call %arg0(%arg1) : (!llvm.float) -> !llvm.i32
|
||||
// CHECK-NEXT: ret i32 %3
|
||||
|
@ -833,7 +833,7 @@ func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i
|
|||
//
|
||||
|
||||
// CHECK-LABEL: define void @cond_br_arguments(i1 {{%.*}}, i1 {{%.*}}) {
|
||||
func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
||||
llvm.func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
||||
// CHECK-NEXT: br i1 %0, label %3, label %5
|
||||
llvm.cond_br %arg0, ^bb1(%arg0 : !llvm.i1), ^bb2
|
||||
|
||||
|
@ -850,15 +850,14 @@ func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: define void @llvm_noalias(float* noalias {{%*.}}) {
|
||||
func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
|
||||
llvm.func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @llvm_varargs(...)
|
||||
func @llvm_varargs()
|
||||
attributes {std.varargs = true}
|
||||
llvm.func @llvm_varargs(...)
|
||||
|
||||
func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
||||
llvm.func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
||||
// CHECK: %2 = inttoptr i32 %0 to i32*
|
||||
// CHECK-NEXT: %3 = ptrtoint i32* %2 to i32
|
||||
%1 = llvm.inttoptr %arg0 : !llvm.i32 to !llvm<"i32*">
|
||||
|
@ -866,19 +865,19 @@ func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
|||
llvm.return %2 : !llvm.i32
|
||||
}
|
||||
|
||||
func @stringconstant() -> !llvm<"i8*"> {
|
||||
llvm.func @stringconstant() -> !llvm<"i8*"> {
|
||||
%1 = llvm.mlir.constant("Hello world!") : !llvm<"i8*">
|
||||
// CHECK: ret [12 x i8] c"Hello world!"
|
||||
llvm.return %1 : !llvm<"i8*">
|
||||
}
|
||||
|
||||
func @noreach() {
|
||||
llvm.func @noreach() {
|
||||
// CHECK: unreachable
|
||||
llvm.unreachable
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define void @fcmp
|
||||
func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
||||
llvm.func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
||||
// CHECK: fcmp oeq float %0, %1
|
||||
// CHECK-NEXT: fcmp ogt float %0, %1
|
||||
// CHECK-NEXT: fcmp oge float %0, %1
|
||||
|
@ -911,7 +910,7 @@ func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @vect
|
||||
func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
||||
llvm.func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
||||
// CHECK-NEXT: extractelement <4 x float> {{.*}}, i32 {{.*}}
|
||||
// CHECK-NEXT: insertelement <4 x float> {{.*}}, float %2, i32 {{.*}}
|
||||
// CHECK-NEXT: shufflevector <4 x float> {{.*}}, <4 x float> {{.*}}, <5 x i32> <i32 0, i32 0, i32 0, i32 0, i32 7>
|
||||
|
@ -922,7 +921,7 @@ func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @alloca
|
||||
func @alloca(%size : !llvm.i64) {
|
||||
llvm.func @alloca(%size : !llvm.i64) {
|
||||
// CHECK: alloca
|
||||
// CHECK-NOT: align
|
||||
llvm.alloca %size x !llvm.i32 {alignment = 0} : (!llvm.i64) -> (!llvm<"i32*">)
|
||||
|
@ -932,13 +931,14 @@ func @alloca(%size : !llvm.i64) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @constants
|
||||
func @constants() -> !llvm<"<4 x float>"> {
|
||||
llvm.func @constants() -> !llvm<"<4 x float>"> {
|
||||
// CHECK: ret <4 x float> <float 4.2{{0*}}e+01, float 0.{{0*}}e+00, float 0.{{0*}}e+00, float 0.{{0*}}e+00>
|
||||
%0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : !llvm<"<4 x float>">
|
||||
llvm.return %0 : !llvm<"<4 x float>">
|
||||
}
|
||||
|
||||
func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
||||
// CHECK-LABEL: @fp_casts
|
||||
llvm.func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
||||
// CHECK: fptrunc double {{.*}} to float
|
||||
%a = llvm.fptrunc %fp2 : !llvm<"double"> to !llvm<"float">
|
||||
// CHECK: fpext float {{.*}} to double
|
||||
|
@ -948,7 +948,8 @@ func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
|||
llvm.return %c : !llvm.i16
|
||||
}
|
||||
|
||||
func @integer_extension_and_truncation(%a : !llvm.i32) {
|
||||
// CHECK-LABEL: @integer_extension_and_truncation
|
||||
llvm.func @integer_extension_and_truncation(%a : !llvm.i32) {
|
||||
// CHECK: sext i32 {{.*}} to i64
|
||||
// CHECK: zext i32 {{.*}} to i64
|
||||
// CHECK: trunc i32 {{.*}} to i16
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
|
||||
|
||||
func @nvvm_special_regs() -> !llvm.i32 {
|
||||
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
llvm.func @nvvm_special_regs() -> !llvm.i32 {
|
||||
// CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
||||
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
|
||||
%2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
||||
|
@ -32,13 +32,13 @@ func @nvvm_special_regs() -> !llvm.i32 {
|
|||
llvm.return %1 : !llvm.i32
|
||||
}
|
||||
|
||||
func @llvm.nvvm.barrier0() {
|
||||
llvm.func @llvm.nvvm.barrier0() {
|
||||
// CHECK: call void @llvm.nvvm.barrier0()
|
||||
nvvm.barrier0
|
||||
llvm.return
|
||||
}
|
||||
|
||||
func @nvvm_shfl(
|
||||
llvm.func @nvvm_shfl(
|
||||
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
||||
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
|
||||
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||
|
@ -48,7 +48,7 @@ func @nvvm_shfl(
|
|||
llvm.return %6 : !llvm.i32
|
||||
}
|
||||
|
||||
func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
||||
llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
||||
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
|
||||
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
|
||||
llvm.return %3 : !llvm.i32
|
||||
|
@ -56,7 +56,7 @@ func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
|||
|
||||
// This function has the "kernel" attribute attached and should appear in the
|
||||
// NVVM annotations after conversion.
|
||||
func @kernel_func() attributes {gpu.kernel} {
|
||||
llvm.func @kernel_func() attributes {gpu.kernel} {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// RUN: mlir-translate -mlir-to-rocdlir %s | FileCheck %s
|
||||
|
||||
func @rocdl_special_regs() -> !llvm.i32 {
|
||||
llvm.func @rocdl_special_regs() -> !llvm.i32 {
|
||||
// CHECK-LABEL: rocdl_special_regs
|
||||
// CHECK: call i32 @llvm.amdgcn.workitem.id.x()
|
||||
%1 = rocdl.workitem.id.x : !llvm.i32
|
||||
|
@ -29,7 +29,7 @@ func @rocdl_special_regs() -> !llvm.i32 {
|
|||
llvm.return %1 : !llvm.i32
|
||||
}
|
||||
|
||||
func @kernel_func() attributes {gpu.kernel} {
|
||||
llvm.func @kernel_func() attributes {gpu.kernel} {
|
||||
// CHECK-LABEL: amdgpu_kernel void @kernel_func
|
||||
llvm.return
|
||||
}
|
||||
|
|
|
@ -12,12 +12,12 @@
|
|||
// RUN: rm %T/test.o
|
||||
|
||||
// Declarations of C library functions.
|
||||
func @fabsf(!llvm.float) -> !llvm.float
|
||||
func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||
func @free(!llvm<"i8*">)
|
||||
llvm.func @fabsf(!llvm.float) -> !llvm.float
|
||||
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||
llvm.func @free(!llvm<"i8*">)
|
||||
|
||||
// Check that a simple function with a nested call works.
|
||||
func @main() -> !llvm.float {
|
||||
llvm.func @main() -> !llvm.float {
|
||||
%0 = llvm.mlir.constant(-4.200000e+02 : f32) : !llvm.float
|
||||
%1 = llvm.call @fabsf(%0) : (!llvm.float) -> !llvm.float
|
||||
llvm.return %1 : !llvm.float
|
||||
|
@ -25,13 +25,13 @@ func @main() -> !llvm.float {
|
|||
// CHECK: 4.200000e+02
|
||||
|
||||
// Helper typed functions wrapping calls to "malloc" and "free".
|
||||
func @allocation() -> !llvm<"float*"> {
|
||||
llvm.func @allocation() -> !llvm<"float*"> {
|
||||
%0 = llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
%1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*">
|
||||
%2 = llvm.bitcast %1 : !llvm<"i8*"> to !llvm<"float*">
|
||||
llvm.return %2 : !llvm<"float*">
|
||||
}
|
||||
func @deallocation(%arg0: !llvm<"float*">) {
|
||||
llvm.func @deallocation(%arg0: !llvm<"float*">) {
|
||||
%0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*">
|
||||
llvm.call @free(%0) : (!llvm<"i8*">) -> ()
|
||||
llvm.return
|
||||
|
@ -39,7 +39,7 @@ func @deallocation(%arg0: !llvm<"float*">) {
|
|||
|
||||
// Check that allocation and deallocation works, and that a custom entry point
|
||||
// works.
|
||||
func @foo() -> !llvm.float {
|
||||
llvm.func @foo() -> !llvm.float {
|
||||
%0 = llvm.call @allocation() : () -> !llvm<"float*">
|
||||
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
%2 = llvm.mlir.constant(1.234000e+03 : f32) : !llvm.float
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
|
||||
|
||||
func @print_0d() {
|
||||
%f = constant 2.00000e+00 : f32
|
||||
|
|
Loading…
Reference in New Issue