forked from OSchip/llvm-project
Use llvm.func to define functions with wrapped LLVM IR function type
This function-like operation allows one to define functions that have wrapped LLVM IR function type, in particular variadic functions. The operation was added in parallel to the existing lowering flow, this commit only switches the flow to use it. Using a custom function type makes the LLVM IR dialect type system more consistent and avoids complex conversion rules for functions that previously had to use the built-in function type instead of a wrapped LLVM IR dialect type and perform conversions during the analysis. PiperOrigin-RevId: 273910855
This commit is contained in:
parent
309b4556d0
commit
5e7959a353
|
@ -152,8 +152,6 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
|
||||||
ConversionTarget target(*module.getContext());
|
ConversionTarget target(*module.getContext());
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
target.addDynamicallyLegalOp<FuncOp>(
|
|
||||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
|
||||||
if (failed(applyFullConversion(module, target, patterns, &converter)))
|
if (failed(applyFullConversion(module, target, patterns, &converter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|
|
@ -138,14 +138,14 @@ public:
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// Get or create the declaration of the printf function in the module.
|
// Get or create the declaration of the printf function in the module.
|
||||||
FuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
|
LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
|
||||||
|
|
||||||
auto print = cast<toy::PrintOp>(op);
|
auto print = cast<toy::PrintOp>(op);
|
||||||
auto loc = print.getLoc();
|
auto loc = print.getLoc();
|
||||||
// We will operate on a MemRef abstraction, we use a type.cast to get one
|
// We will operate on a MemRef abstraction, we use a type.cast to get one
|
||||||
// if our operand is still a Toy array.
|
// if our operand is still a Toy array.
|
||||||
Value *operand = memRefTypeCast(rewriter, operands[0]);
|
Value *operand = memRefTypeCast(rewriter, operands[0]);
|
||||||
Type retTy = printfFunc.getType().getResult(0);
|
Type retTy = printfFunc.getType().getFunctionResultType();
|
||||||
|
|
||||||
// Create our loop nest now
|
// Create our loop nest now
|
||||||
using namespace edsc;
|
using namespace edsc;
|
||||||
|
@ -218,24 +218,23 @@ private:
|
||||||
|
|
||||||
/// Return the prototype declaration for printf in the module, create it if
|
/// Return the prototype declaration for printf in the module, create it if
|
||||||
/// necessary.
|
/// necessary.
|
||||||
FuncOp getPrintf(ModuleOp module) const {
|
LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
|
||||||
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
|
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
|
||||||
if (printfFunc)
|
if (printfFunc)
|
||||||
return printfFunc;
|
return printfFunc;
|
||||||
|
|
||||||
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
|
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
|
||||||
Builder builder(module);
|
OpBuilder builder(module.getBodyRegion());
|
||||||
auto *dialect =
|
auto *dialect =
|
||||||
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
|
||||||
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
|
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
|
||||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
|
||||||
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
|
auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
|
||||||
printfFunc = FuncOp::create(builder.getUnknownLoc(), "printf", printfTy);
|
/*isVarArg=*/true);
|
||||||
// It should be variadic, but we don't support it fully just yet.
|
return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
|
||||||
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
|
printfTy,
|
||||||
module.push_back(printfFunc);
|
ArrayRef<NamedAttribute>());
|
||||||
return printfFunc;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -369,10 +368,10 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
|
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
|
||||||
LLVM::LLVMDialect, StandardOpsDialect>();
|
LLVM::LLVMDialect, StandardOpsDialect>();
|
||||||
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
|
|
||||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||||
return typeConverter.isSignatureLegal(op.getType());
|
return typeConverter.isSignatureLegal(op.getType());
|
||||||
});
|
});
|
||||||
|
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
|
||||||
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
|
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
|
||||||
&typeConverter))) {
|
&typeConverter))) {
|
||||||
emitError(UnknownLoc::get(getModule().getContext()),
|
emitError(UnknownLoc::get(getModule().getContext()),
|
||||||
|
|
|
@ -137,49 +137,27 @@ Examples:
|
||||||
|
|
||||||
### Function Signature Conversion
|
### Function Signature Conversion
|
||||||
|
|
||||||
MLIR function type is built into the representation, even the functions in
|
LLVM IR functions are defined by a custom operation. The function itself has a
|
||||||
dialects including a first-class function type must have the built-in MLIR
|
wrapped LLVM IR function type converted as described above. The function
|
||||||
function type. During the conversion to LLVM IR, function signatures are
|
definition operation uses MLIR syntax.
|
||||||
converted as follows:
|
|
||||||
|
|
||||||
- the outer type remains the built-in MLIR function;
|
|
||||||
- function arguments are converted individually following these rules;
|
|
||||||
- function results:
|
|
||||||
- zero-result functions remain zero-result;
|
|
||||||
- single-result functions have their result type converted according to
|
|
||||||
these rules;
|
|
||||||
- multi-result functions have a single result type of the wrapped LLVM IR
|
|
||||||
structure type with elements corresponding to the converted original
|
|
||||||
results.
|
|
||||||
|
|
||||||
Rationale: function definitions remain analyzable within MLIR without having to
|
|
||||||
abstract away the function type. In order to remain consistent with the regular
|
|
||||||
MLIR functions, we do not introduce a `void` result type since we cannot create
|
|
||||||
a value of `void` type that MLIR passes might expect to be returned from a
|
|
||||||
function.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```mlir {.mlir}
|
```mlir {.mlir}
|
||||||
// zero-ary function type with no results.
|
// zero-ary function type with no results.
|
||||||
func @foo() -> ()
|
func @foo() -> ()
|
||||||
// remains as is
|
// gets LLVM type void().
|
||||||
func @foo() -> ()
|
llvm.func @foo() -> ()
|
||||||
|
|
||||||
// unary function with one result
|
// function with one result
|
||||||
func @bar(i32) -> (i64)
|
func @bar(i32) -> (i64)
|
||||||
// has its argument and result type converted
|
// gets converted to LLVM type i64(i32).
|
||||||
func @bar(!llvm.type<"i32">) -> !llvm.type<"i64">
|
func @bar(!llvm.i32) -> !llvm.i64
|
||||||
|
|
||||||
// binary function with one result
|
// function with two results
|
||||||
func @baz(i32, f32) -> (i64)
|
|
||||||
// has its arguments handled separately
|
|
||||||
func @baz(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"i64">
|
|
||||||
|
|
||||||
// binary function with two results
|
|
||||||
func @qux(i32, f32) -> (i64, f64)
|
func @qux(i32, f32) -> (i64, f64)
|
||||||
// has its result aggregated into a structure type
|
// has its result aggregated into a structure type
|
||||||
func @qux(!llvm.type<"i32">, !llvm.type<"float">) -> !llvm.type<"{i64, double}">
|
func @qux(!llvm.i32, !llvm.float) -> !llvm.type<"{i64, double}">
|
||||||
|
|
||||||
// function-typed arguments or results in higher-order functions
|
// function-typed arguments or results in higher-order functions
|
||||||
func @quux(() -> ()) -> (() -> ())
|
func @quux(() -> ()) -> (() -> ())
|
||||||
|
|
|
@ -50,6 +50,30 @@ specific LLVM IR type.
|
||||||
All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
|
All operations in the LLVM IR dialect have a custom form in MLIR. The mnemonic
|
||||||
of an operation is that used in LLVM IR prefixed with "`llvm.`".
|
of an operation is that used in LLVM IR prefixed with "`llvm.`".
|
||||||
|
|
||||||
|
### LLVM functions
|
||||||
|
|
||||||
|
MLIR functions are defined by an operation that is not built into the IR itself.
|
||||||
|
The LLVM IR dialect provides an `llvm.func` operation to define functions
|
||||||
|
compatible with LLVM IR. These functions have wrapped LLVM IR function type but
|
||||||
|
use MLIR syntax to express it. They are required to have exactly one result
|
||||||
|
type. LLVM function operation is intended to capture additional properties of
|
||||||
|
LLVM functions, such as linkage and calling convention, that may be modeled
|
||||||
|
differently by the built-in MLIR function.
|
||||||
|
|
||||||
|
```mlir {.mlir}
|
||||||
|
// The type of @bar is !llvm<"i64 (i64)">
|
||||||
|
llvm.func @bar(%arg0: !llvm.i64) -> !llvm.i64 {
|
||||||
|
llvm.return %arg0 : !llvm.i64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type type of @foo is !llvm<"void (i64)">
|
||||||
|
// !llvm.void type is omitted
|
||||||
|
llvm.func @foo(%arg0: !llvm.i64) {
|
||||||
|
llvm.return
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
### LLVM IR operations
|
### LLVM IR operations
|
||||||
|
|
||||||
The following operations are currently supported. The semantics of these
|
The following operations are currently supported. The semantics of these
|
||||||
|
|
|
@ -25,15 +25,12 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
class FuncOp;
|
|
||||||
class Location;
|
class Location;
|
||||||
class ModuleOp;
|
class ModuleOp;
|
||||||
class OpBuilder;
|
|
||||||
class Value;
|
|
||||||
|
|
||||||
namespace LLVM {
|
namespace LLVM {
|
||||||
class LLVMDialect;
|
class LLVMDialect;
|
||||||
}
|
} // namespace LLVM
|
||||||
|
|
||||||
template <typename T> class OpPassBase;
|
template <typename T> class OpPassBase;
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,12 @@ public:
|
||||||
/// non-standard or non-builtin types.
|
/// non-standard or non-builtin types.
|
||||||
Type convertType(Type t) override;
|
Type convertType(Type t) override;
|
||||||
|
|
||||||
|
/// Convert a function type. The arguments and results are converted one by
|
||||||
|
/// one and results are packed into a wrapped LLVM IR structure type. `result`
|
||||||
|
/// is populated with argument mapping.
|
||||||
|
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
|
||||||
|
SignatureConversion &result);
|
||||||
|
|
||||||
/// Convert a non-empty list of types to be returned from a function into a
|
/// Convert a non-empty list of types to be returned from a function into a
|
||||||
/// supported LLVM IR type. In particular, if more than one values is
|
/// supported LLVM IR type. In particular, if more than one values is
|
||||||
/// returned, create an LLVM IR structure type with elements that correspond
|
/// returned, create an LLVM IR structure type with elements that correspond
|
||||||
|
|
|
@ -55,7 +55,7 @@ public:
|
||||||
|
|
||||||
/// Returns whether the given function is a kernel function, i.e., has the
|
/// Returns whether the given function is a kernel function, i.e., has the
|
||||||
/// 'gpu.kernel' attribute.
|
/// 'gpu.kernel' attribute.
|
||||||
static bool isKernel(FuncOp function);
|
static bool isKernel(Operation *op);
|
||||||
|
|
||||||
LogicalResult verifyOperationAttribute(Operation *op,
|
LogicalResult verifyOperationAttribute(Operation *op,
|
||||||
NamedAttribute attr) override;
|
NamedAttribute attr) override;
|
||||||
|
|
|
@ -64,6 +64,9 @@ public:
|
||||||
LLVMDialect &getDialect();
|
LLVMDialect &getDialect();
|
||||||
llvm::Type *getUnderlyingType() const;
|
llvm::Type *getUnderlyingType() const;
|
||||||
|
|
||||||
|
/// Utilities to identify types.
|
||||||
|
bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
|
||||||
|
|
||||||
/// Array type utilities.
|
/// Array type utilities.
|
||||||
LLVMType getArrayElementType();
|
LLVMType getArrayElementType();
|
||||||
unsigned getArrayNumElements();
|
unsigned getArrayNumElements();
|
||||||
|
|
|
@ -525,11 +525,15 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
|
OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
|
||||||
"LLVMType type, ArrayRef<NamedAttribute> attrs, "
|
"LLVMType type, ArrayRef<NamedAttribute> attrs = {}, "
|
||||||
"ArrayRef<NamedAttributeList> argAttrs = {}">
|
"ArrayRef<NamedAttributeList> argAttrs = {}">
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
// Add an entry block to an empty function, and set up the block arguments
|
||||||
|
// to match the signature of the function.
|
||||||
|
Block *addEntryBlock();
|
||||||
|
|
||||||
LLVMType getType() {
|
LLVMType getType() {
|
||||||
return getAttrOfType<TypeAttr>(getTypeAttrName())
|
return getAttrOfType<TypeAttr>(getTypeAttrName())
|
||||||
.getValue().cast<LLVMType>();
|
.getValue().cast<LLVMType>();
|
||||||
|
|
|
@ -34,13 +34,14 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Attribute;
|
class Attribute;
|
||||||
class FuncOp;
|
|
||||||
class Location;
|
class Location;
|
||||||
class ModuleOp;
|
class ModuleOp;
|
||||||
class Operation;
|
class Operation;
|
||||||
|
|
||||||
namespace LLVM {
|
namespace LLVM {
|
||||||
|
|
||||||
|
class LLVMFuncOp;
|
||||||
|
|
||||||
// Implementation class for module translation. Holds a reference to the module
|
// Implementation class for module translation. Holds a reference to the module
|
||||||
// being translated, and the mappings between the original and the translated
|
// being translated, and the mappings between the original and the translated
|
||||||
// functions, basic blocks and values. It is practically easier to hold these
|
// functions, basic blocks and values. It is practically easier to hold these
|
||||||
|
@ -75,8 +76,8 @@ protected:
|
||||||
private:
|
private:
|
||||||
LogicalResult convertFunctions();
|
LogicalResult convertFunctions();
|
||||||
void convertGlobals();
|
void convertGlobals();
|
||||||
LogicalResult convertOneFunction(FuncOp func);
|
LogicalResult convertOneFunction(LLVMFuncOp func);
|
||||||
void connectPHINodes(FuncOp func);
|
void connectPHINodes(LLVMFuncOp func);
|
||||||
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
|
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
|
||||||
|
|
||||||
template <typename Range>
|
template <typename Range>
|
||||||
|
|
|
@ -80,6 +80,7 @@ private:
|
||||||
|
|
||||||
void initializeCachedTypes() {
|
void initializeCachedTypes() {
|
||||||
const llvm::Module &module = llvmDialect->getLLVMModule();
|
const llvm::Module &module = llvmDialect->getLLVMModule();
|
||||||
|
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||||
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
llvmPointerPointerType = llvmPointerType.getPointerTo();
|
llvmPointerPointerType = llvmPointerType.getPointerTo();
|
||||||
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
|
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
|
||||||
|
@ -89,6 +90,8 @@ private:
|
||||||
llvmDialect, module.getDataLayout().getPointerSizeInBits());
|
llvmDialect, module.getDataLayout().getPointerSizeInBits());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LLVM::LLVMType getVoidType() { return llvmVoidType; }
|
||||||
|
|
||||||
LLVM::LLVMType getPointerType() { return llvmPointerType; }
|
LLVM::LLVMType getPointerType() { return llvmPointerType; }
|
||||||
|
|
||||||
LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
|
LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
|
||||||
|
@ -120,7 +123,7 @@ private:
|
||||||
|
|
||||||
void declareCudaFunctions(Location loc);
|
void declareCudaFunctions(Location loc);
|
||||||
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
|
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
|
||||||
Value *generateKernelNameConstant(StringRef name, Location &loc,
|
Value *generateKernelNameConstant(StringRef name, Location loc,
|
||||||
OpBuilder &builder);
|
OpBuilder &builder);
|
||||||
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
|
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
|
||||||
|
|
||||||
|
@ -145,6 +148,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LLVM::LLVMDialect *llvmDialect;
|
LLVM::LLVMDialect *llvmDialect;
|
||||||
|
LLVM::LLVMType llvmVoidType;
|
||||||
LLVM::LLVMType llvmPointerType;
|
LLVM::LLVMType llvmPointerType;
|
||||||
LLVM::LLVMType llvmPointerPointerType;
|
LLVM::LLVMType llvmPointerPointerType;
|
||||||
LLVM::LLVMType llvmInt8Type;
|
LLVM::LLVMType llvmInt8Type;
|
||||||
|
@ -160,38 +164,41 @@ private:
|
||||||
// uses void pointers. This is fine as they have the same linkage in C.
|
// uses void pointers. This is fine as they have the same linkage in C.
|
||||||
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||||
ModuleOp module = getModule();
|
ModuleOp module = getModule();
|
||||||
Builder builder(module);
|
OpBuilder builder(module.getBody()->getTerminator());
|
||||||
if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
|
if (!module.lookupSymbol(cuModuleLoadName)) {
|
||||||
module.push_back(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
FuncOp::create(loc, cuModuleLoadName,
|
loc, cuModuleLoadName,
|
||||||
builder.getFunctionType(
|
LLVM::LLVMType::getFunctionTy(
|
||||||
{
|
getCUResultType(),
|
||||||
getPointerPointerType(), /* CUmodule *module */
|
{
|
||||||
getPointerType() /* void *cubin */
|
getPointerPointerType(), /* CUmodule *module */
|
||||||
},
|
getPointerType() /* void *cubin */
|
||||||
getCUResultType())));
|
},
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
|
if (!module.lookupSymbol(cuModuleGetFunctionName)) {
|
||||||
// The helper uses void* instead of CUDA's opaque CUmodule and
|
// The helper uses void* instead of CUDA's opaque CUmodule and
|
||||||
// CUfunction.
|
// CUfunction.
|
||||||
module.push_back(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
FuncOp::create(loc, cuModuleGetFunctionName,
|
loc, cuModuleGetFunctionName,
|
||||||
builder.getFunctionType(
|
LLVM::LLVMType::getFunctionTy(
|
||||||
{
|
getCUResultType(),
|
||||||
getPointerPointerType(), /* void **function */
|
{
|
||||||
getPointerType(), /* void *module */
|
getPointerPointerType(), /* void **function */
|
||||||
getPointerType() /* char *name */
|
getPointerType(), /* void *module */
|
||||||
},
|
getPointerType() /* char *name */
|
||||||
getCUResultType())));
|
},
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
|
if (!module.lookupSymbol(cuLaunchKernelName)) {
|
||||||
// Other than the CUDA api, the wrappers use uintptr_t to match the
|
// Other than the CUDA api, the wrappers use uintptr_t to match the
|
||||||
// LLVM type if MLIR's index type, which the GPU dialect uses.
|
// LLVM type if MLIR's index type, which the GPU dialect uses.
|
||||||
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
|
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
|
||||||
// CUstream.
|
// CUstream.
|
||||||
module.push_back(FuncOp::create(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
loc, cuLaunchKernelName,
|
loc, cuLaunchKernelName,
|
||||||
builder.getFunctionType(
|
LLVM::LLVMType::getFunctionTy(
|
||||||
|
getCUResultType(),
|
||||||
{
|
{
|
||||||
getPointerType(), /* void* f */
|
getPointerType(), /* void* f */
|
||||||
getIntPtrType(), /* intptr_t gridXDim */
|
getIntPtrType(), /* intptr_t gridXDim */
|
||||||
|
@ -205,32 +212,31 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||||
getPointerPointerType(), /* void **kernelParams */
|
getPointerPointerType(), /* void **kernelParams */
|
||||||
getPointerPointerType() /* void **extra */
|
getPointerPointerType() /* void **extra */
|
||||||
},
|
},
|
||||||
getCUResultType())));
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
|
if (!module.lookupSymbol(cuGetStreamHelperName)) {
|
||||||
// Helper function to get the current CUDA stream. Uses void* instead of
|
// Helper function to get the current CUDA stream. Uses void* instead of
|
||||||
// CUDAs opaque CUstream.
|
// CUDAs opaque CUstream.
|
||||||
module.push_back(FuncOp::create(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
loc, cuGetStreamHelperName,
|
loc, cuGetStreamHelperName,
|
||||||
builder.getFunctionType({}, getPointerType() /* void *stream */)));
|
LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
|
if (!module.lookupSymbol(cuStreamSynchronizeName)) {
|
||||||
module.push_back(
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
FuncOp::create(loc, cuStreamSynchronizeName,
|
loc, cuStreamSynchronizeName,
|
||||||
builder.getFunctionType(
|
LLVM::LLVMType::getFunctionTy(getCUResultType(),
|
||||||
{
|
getPointerType() /* CUstream stream */,
|
||||||
getPointerType() /* CUstream stream */
|
/*isVarArg=*/false));
|
||||||
},
|
|
||||||
getCUResultType())));
|
|
||||||
}
|
}
|
||||||
if (!module.lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr)) {
|
if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) {
|
||||||
module.push_back(FuncOp::create(loc, kMcuMemHostRegisterPtr,
|
builder.create<LLVM::LLVMFuncOp>(
|
||||||
builder.getFunctionType(
|
loc, kMcuMemHostRegisterPtr,
|
||||||
{
|
LLVM::LLVMType::getFunctionTy(getVoidType(),
|
||||||
getPointerType(), /* void *ptr */
|
{
|
||||||
getInt32Type() /* int32 flags*/
|
getPointerType(), /* void *ptr */
|
||||||
},
|
getInt32Type() /* int32 flags*/
|
||||||
{})));
|
},
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,7 +277,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
||||||
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
|
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
|
||||||
if (llvmType.isStructTy()) {
|
if (llvmType.isStructTy()) {
|
||||||
auto registerFunc =
|
auto registerFunc =
|
||||||
getModule().lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr);
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr);
|
||||||
auto zero = builder.create<LLVM::ConstantOp>(
|
auto zero = builder.create<LLVM::ConstantOp>(
|
||||||
loc, getInt32Type(), builder.getI32IntegerAttr(0));
|
loc, getInt32Type(), builder.getI32IntegerAttr(0));
|
||||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
|
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
|
||||||
|
@ -304,7 +310,7 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
||||||
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
|
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
|
||||||
// }
|
// }
|
||||||
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
|
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
|
||||||
StringRef name, Location &loc, OpBuilder &builder) {
|
StringRef name, Location loc, OpBuilder &builder) {
|
||||||
// Make sure the trailing zero is included in the constant.
|
// Make sure the trailing zero is included in the constant.
|
||||||
std::vector<char> kernelName(name.begin(), name.end());
|
std::vector<char> kernelName(name.begin(), name.end());
|
||||||
kernelName.push_back('\0');
|
kernelName.push_back('\0');
|
||||||
|
@ -355,6 +361,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
||||||
<< "missing " << kCubinAnnotation << " attribute";
|
<< "missing " << kCubinAnnotation << " attribute";
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(kernelModule.getName() && "expected a named module");
|
assert(kernelModule.getName() && "expected a named module");
|
||||||
SmallString<128> nameBuffer(*kernelModule.getName());
|
SmallString<128> nameBuffer(*kernelModule.getName());
|
||||||
nameBuffer.append(kCubinStorageSuffix);
|
nameBuffer.append(kCubinStorageSuffix);
|
||||||
|
@ -364,7 +371,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
||||||
// Emit the load module call to load the module data. Error checking is done
|
// Emit the load module call to load the module data. Error checking is done
|
||||||
// in the called helper function.
|
// in the called helper function.
|
||||||
auto cuModule = allocatePointer(builder, loc);
|
auto cuModule = allocatePointer(builder, loc);
|
||||||
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
|
auto cuModuleLoad =
|
||||||
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
|
||||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||||
builder.getSymbolRefAttr(cuModuleLoad),
|
builder.getSymbolRefAttr(cuModuleLoad),
|
||||||
ArrayRef<Value *>{cuModule, data});
|
ArrayRef<Value *>{cuModule, data});
|
||||||
|
@ -374,20 +382,21 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
||||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
|
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
|
||||||
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
|
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
|
||||||
auto cuFunction = allocatePointer(builder, loc);
|
auto cuFunction = allocatePointer(builder, loc);
|
||||||
FuncOp cuModuleGetFunction =
|
auto cuModuleGetFunction =
|
||||||
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
|
||||||
builder.create<LLVM::CallOp>(
|
builder.create<LLVM::CallOp>(
|
||||||
loc, ArrayRef<Type>{getCUResultType()},
|
loc, ArrayRef<Type>{getCUResultType()},
|
||||||
builder.getSymbolRefAttr(cuModuleGetFunction),
|
builder.getSymbolRefAttr(cuModuleGetFunction),
|
||||||
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
||||||
// Grab the global stream needed for execution.
|
// Grab the global stream needed for execution.
|
||||||
FuncOp cuGetStreamHelper =
|
auto cuGetStreamHelper =
|
||||||
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
|
||||||
auto cuStream = builder.create<LLVM::CallOp>(
|
auto cuStream = builder.create<LLVM::CallOp>(
|
||||||
loc, ArrayRef<Type>{getPointerType()},
|
loc, ArrayRef<Type>{getPointerType()},
|
||||||
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
||||||
// Invoke the function with required arguments.
|
// Invoke the function with required arguments.
|
||||||
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
|
auto cuLaunchKernel =
|
||||||
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
|
||||||
auto cuFunctionRef =
|
auto cuFunctionRef =
|
||||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
|
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
|
||||||
auto paramsArray = setupParamsArray(launchOp, builder);
|
auto paramsArray = setupParamsArray(launchOp, builder);
|
||||||
|
@ -404,7 +413,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
||||||
paramsArray, /* kernel params */
|
paramsArray, /* kernel params */
|
||||||
nullpointer /* extra */});
|
nullpointer /* extra */});
|
||||||
// Sync on the stream to make it synchronous.
|
// Sync on the stream to make it synchronous.
|
||||||
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
|
auto cuStreamSync =
|
||||||
|
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
|
||||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||||
builder.getSymbolRefAttr(cuStreamSync),
|
builder.getSymbolRefAttr(cuStreamSync),
|
||||||
ArrayRef<Value *>(cuStream.getResult(0)));
|
ArrayRef<Value *>(cuStream.getResult(0)));
|
||||||
|
|
|
@ -381,8 +381,6 @@ public:
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
target.addLegalDialect<NVVM::NVVMDialect>();
|
target.addLegalDialect<NVVM::NVVMDialect>();
|
||||||
target.addDynamicallyLegalOp<FuncOp>(
|
|
||||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
|
||||||
if (failed(applyPartialConversion(m, target, patterns, &converter)))
|
if (failed(applyPartialConversion(m, target, patterns, &converter)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,19 +95,31 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Except for signatures, MLIR function types are converted into LLVM
|
||||||
|
// pointer-to-function types.
|
||||||
|
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
||||||
|
SignatureConversion conversion(type.getNumInputs());
|
||||||
|
LLVM::LLVMType converted =
|
||||||
|
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
|
||||||
|
return converted.getPointerTo();
|
||||||
|
}
|
||||||
|
|
||||||
// Function types are converted to LLVM Function types by recursively converting
|
// Function types are converted to LLVM Function types by recursively converting
|
||||||
// argument and result types. If MLIR Function has zero results, the LLVM
|
// argument and result types. If MLIR Function has zero results, the LLVM
|
||||||
// Function has one VoidType result. If MLIR Function has more than one result,
|
// Function has one VoidType result. If MLIR Function has more than one result,
|
||||||
// they are into an LLVM StructType in their order of appearance.
|
// they are into an LLVM StructType in their order of appearance.
|
||||||
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
|
||||||
|
FunctionType type, bool isVariadic,
|
||||||
|
LLVMTypeConverter::SignatureConversion &result) {
|
||||||
// Convert argument types one by one and check for errors.
|
// Convert argument types one by one and check for errors.
|
||||||
SmallVector<LLVM::LLVMType, 8> argTypes;
|
for (auto &en : llvm::enumerate(type.getInputs()))
|
||||||
for (auto t : type.getInputs()) {
|
if (failed(convertSignatureArg(en.index(), en.value(), result)))
|
||||||
auto converted = convertType(t);
|
|
||||||
if (!converted)
|
|
||||||
return {};
|
return {};
|
||||||
argTypes.push_back(unwrap(converted));
|
|
||||||
}
|
SmallVector<LLVM::LLVMType, 8> argTypes;
|
||||||
|
argTypes.reserve(llvm::size(result.getConvertedTypes()));
|
||||||
|
for (Type type : result.getConvertedTypes())
|
||||||
|
argTypes.push_back(unwrap(type));
|
||||||
|
|
||||||
// If function does not return anything, create the void result type,
|
// If function does not return anything, create the void result type,
|
||||||
// if it returns on element, convert it, otherwise pack the result types into
|
// if it returns on element, convert it, otherwise pack the result types into
|
||||||
|
@ -118,8 +130,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
||||||
: unwrap(packFunctionResults(type.getResults()));
|
: unwrap(packFunctionResults(type.getResults()));
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return {};
|
return {};
|
||||||
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
|
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
|
||||||
.getPointerTo();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
||||||
|
@ -249,6 +260,10 @@ public:
|
||||||
&dialect, getModule().getDataLayout().getPointerSizeInBits());
|
&dialect, getModule().getDataLayout().getPointerSizeInBits());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LLVM::LLVMType getVoidType() const {
|
||||||
|
return LLVM::LLVMType::getVoidTy(&dialect);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the MLIR type wrapping the LLVM i8* type.
|
// Get the MLIR type wrapping the LLVM i8* type.
|
||||||
LLVM::LLVMType getVoidPtrType() const {
|
LLVM::LLVMType getVoidPtrType() const {
|
||||||
return LLVM::LLVMType::getInt8PtrTy(&dialect);
|
return LLVM::LLVMType::getInt8PtrTy(&dialect);
|
||||||
|
@ -289,7 +304,16 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto funcOp = cast<FuncOp>(op);
|
auto funcOp = cast<FuncOp>(op);
|
||||||
FunctionType type = funcOp.getType();
|
FunctionType type = funcOp.getType();
|
||||||
SmallVector<Type, 4> argTypes;
|
// Pack the result types into a struct.
|
||||||
|
Type packedResult;
|
||||||
|
if (type.getNumResults() != 0)
|
||||||
|
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
|
||||||
|
return matchFailure();
|
||||||
|
LLVM::LLVMType resultType = packedResult
|
||||||
|
? packedResult.cast<LLVM::LLVMType>()
|
||||||
|
: LLVM::LLVMType::getVoidTy(&dialect);
|
||||||
|
|
||||||
|
SmallVector<LLVM::LLVMType, 4> argTypes;
|
||||||
argTypes.reserve(type.getNumInputs());
|
argTypes.reserve(type.getNumInputs());
|
||||||
SmallVector<unsigned, 4> promotedArgIndices;
|
SmallVector<unsigned, 4> promotedArgIndices;
|
||||||
promotedArgIndices.reserve(type.getNumInputs());
|
promotedArgIndices.reserve(type.getNumInputs());
|
||||||
|
@ -297,14 +321,15 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
||||||
// Convert the original function arguments. Struct arguments are promoted to
|
// Convert the original function arguments. Struct arguments are promoted to
|
||||||
// pointer to struct arguments to allow calling external functions with
|
// pointer to struct arguments to allow calling external functions with
|
||||||
// various ABIs (e.g. compiled from C/C++ on platform X).
|
// various ABIs (e.g. compiled from C/C++ on platform X).
|
||||||
TypeConverter::SignatureConversion result(type.getNumInputs());
|
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
|
||||||
|
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||||
for (auto en : llvm::enumerate(type.getInputs())) {
|
for (auto en : llvm::enumerate(type.getInputs())) {
|
||||||
auto t = en.value();
|
auto t = en.value();
|
||||||
auto converted = lowering.convertType(t);
|
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
|
||||||
if (!converted)
|
if (!converted)
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
if (t.isa<MemRefType>()) {
|
if (t.isa<MemRefType>()) {
|
||||||
converted = converted.cast<LLVM::LLVMType>().getPointerTo();
|
converted = converted.getPointerTo();
|
||||||
promotedArgIndices.push_back(en.index());
|
promotedArgIndices.push_back(en.index());
|
||||||
}
|
}
|
||||||
argTypes.push_back(converted);
|
argTypes.push_back(converted);
|
||||||
|
@ -312,21 +337,24 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
||||||
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
|
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
|
||||||
result.addInputs(idx, argTypes[idx]);
|
result.addInputs(idx, argTypes[idx]);
|
||||||
|
|
||||||
// Pack the result types into a struct.
|
auto llvmType = LLVM::LLVMType::getFunctionTy(
|
||||||
Type packedResult;
|
resultType, argTypes, varargsAttr && varargsAttr.getValue());
|
||||||
if (type.getNumResults() != 0) {
|
|
||||||
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
|
// Only retain those attributes that are not constructed by build.
|
||||||
return matchFailure();
|
SmallVector<NamedAttribute, 4> attributes;
|
||||||
|
for (const auto &attr : funcOp.getAttrs()) {
|
||||||
|
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
|
||||||
|
attr.first.is(impl::getTypeAttrName()) ||
|
||||||
|
attr.first.is("std.varargs"))
|
||||||
|
continue;
|
||||||
|
attributes.push_back(attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new function with an updated signature.
|
// Create an LLVM funcion.
|
||||||
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||||
|
op->getLoc(), funcOp.getName(), llvmType, attributes);
|
||||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||||
newFuncOp.end());
|
newFuncOp.end());
|
||||||
newFuncOp.setType(FunctionType::get(
|
|
||||||
result.getConvertedTypes(),
|
|
||||||
packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
|
|
||||||
funcOp.getContext()));
|
|
||||||
|
|
||||||
// Tell the rewriter to convert the region signature.
|
// Tell the rewriter to convert the region signature.
|
||||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
|
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
|
||||||
|
@ -627,13 +655,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
||||||
|
|
||||||
// Insert the `malloc` declaration if it is not already present.
|
// Insert the `malloc` declaration if it is not already present.
|
||||||
auto module = op->getParentOfType<ModuleOp>();
|
auto module = op->getParentOfType<ModuleOp>();
|
||||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||||
if (!mallocFunc) {
|
if (!mallocFunc) {
|
||||||
auto mallocType =
|
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||||
mallocFunc =
|
rewriter.getUnknownLoc(), "malloc",
|
||||||
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
|
||||||
module.push_back(mallocFunc);
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||||
|
@ -792,12 +820,14 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
||||||
OperandAdaptor<DeallocOp> transformed(operands);
|
OperandAdaptor<DeallocOp> transformed(operands);
|
||||||
|
|
||||||
// Insert the `free` declaration if it is not already present.
|
// Insert the `free` declaration if it is not already present.
|
||||||
FuncOp freeFunc =
|
auto freeFunc =
|
||||||
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
|
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
|
||||||
if (!freeFunc) {
|
if (!freeFunc) {
|
||||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||||
op->getParentOfType<ModuleOp>().push_back(freeFunc);
|
rewriter.getUnknownLoc(), "free",
|
||||||
|
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||||
|
@ -1373,9 +1403,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||||
|
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
|
||||||
return typeConverter->isSignatureLegal(op.getType());
|
|
||||||
});
|
|
||||||
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
|
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,5 +6,5 @@ add_llvm_library(MLIRGPU
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
|
||||||
)
|
)
|
||||||
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR LLVMSupport)
|
add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR MLIRLLVMIR LLVMSupport)
|
||||||
target_link_libraries(MLIRGPU MLIRIR MLIRStandardOps LLVMSupport)
|
target_link_libraries(MLIRGPU MLIRIR MLIRLLVMIR MLIRStandardOps LLVMSupport)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
|
@ -37,9 +38,8 @@ using namespace mlir::gpu;
|
||||||
|
|
||||||
StringRef GPUDialect::getDialectName() { return "gpu"; }
|
StringRef GPUDialect::getDialectName() { return "gpu"; }
|
||||||
|
|
||||||
bool GPUDialect::isKernel(FuncOp function) {
|
bool GPUDialect::isKernel(Operation *op) {
|
||||||
UnitAttr isKernelAttr =
|
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||||
function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
|
||||||
return static_cast<bool>(isKernelAttr);
|
return static_cast<bool>(isKernelAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,18 +92,25 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
|
||||||
|
|
||||||
// Check that `launch_func` refers to a well-formed kernel function.
|
// Check that `launch_func` refers to a well-formed kernel function.
|
||||||
StringRef kernelName = launchOp.kernel();
|
StringRef kernelName = launchOp.kernel();
|
||||||
auto kernelFunction = kernelModule.lookupSymbol<FuncOp>(kernelName);
|
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
|
||||||
if (!kernelFunction)
|
auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
|
||||||
|
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
|
||||||
|
if (!kernelStdFunction && !kernelLLVMFunction)
|
||||||
return launchOp.emitOpError("kernel function '")
|
return launchOp.emitOpError("kernel function '")
|
||||||
<< kernelName << "' is undefined";
|
<< kernelName << "' is undefined";
|
||||||
if (!kernelFunction.getAttrOfType<mlir::UnitAttr>(
|
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
|
||||||
GPUDialect::getKernelFuncAttrName()))
|
GPUDialect::getKernelFuncAttrName()))
|
||||||
return launchOp.emitOpError("kernel function is missing the '")
|
return launchOp.emitOpError("kernel function is missing the '")
|
||||||
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
|
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
|
||||||
if (launchOp.getNumKernelOperands() != kernelFunction.getNumArguments())
|
|
||||||
return launchOp.emitOpError("got ") << launchOp.getNumKernelOperands()
|
unsigned actualNumArguments = launchOp.getNumKernelOperands();
|
||||||
<< " kernel operands but expected "
|
unsigned expectedNumArguments = kernelLLVMFunction
|
||||||
<< kernelFunction.getNumArguments();
|
? kernelLLVMFunction.getNumArguments()
|
||||||
|
: kernelStdFunction.getNumArguments();
|
||||||
|
if (expectedNumArguments != actualNumArguments)
|
||||||
|
return launchOp.emitOpError("got ")
|
||||||
|
<< actualNumArguments << " kernel operands but expected "
|
||||||
|
<< expectedNumArguments;
|
||||||
|
|
||||||
// Due to the ordering of the current impl of lowering and LLVMLowering,
|
// Due to the ordering of the current impl of lowering and LLVMLowering,
|
||||||
// type checks need to be temporarily disabled.
|
// type checks need to be temporarily disabled.
|
||||||
|
|
|
@ -178,15 +178,17 @@ public:
|
||||||
auto indexType = IndexType::get(op->getContext());
|
auto indexType = IndexType::get(op->getContext());
|
||||||
auto voidPtrTy =
|
auto voidPtrTy =
|
||||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
|
||||||
|
.cast<LLVM::LLVMType>();
|
||||||
// Insert the `malloc` declaration if it is not already present.
|
// Insert the `malloc` declaration if it is not already present.
|
||||||
auto module = op->getParentOfType<ModuleOp>();
|
auto module = op->getParentOfType<ModuleOp>();
|
||||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
auto mallocFunc = module.lookupSymbol<LLVMFuncOp>("malloc");
|
||||||
if (!mallocFunc) {
|
if (!mallocFunc) {
|
||||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||||
mallocFunc =
|
mallocFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||||
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
rewriter.getUnknownLoc(), "malloc",
|
||||||
module.push_back(mallocFunc);
|
LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty,
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get MLIR types for injecting element pointer.
|
// Get MLIR types for injecting element pointer.
|
||||||
|
@ -257,15 +259,18 @@ public:
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect());
|
||||||
auto voidPtrTy =
|
auto voidPtrTy =
|
||||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||||
// Insert the `free` declaration if it is not already present.
|
// Insert the `free` declaration if it is not already present.
|
||||||
auto module = op->getParentOfType<ModuleOp>();
|
auto module = op->getParentOfType<ModuleOp>();
|
||||||
FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
|
auto freeFunc = module.lookupSymbol<LLVMFuncOp>("free");
|
||||||
if (!freeFunc) {
|
if (!freeFunc) {
|
||||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
freeFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||||
module.push_back(freeFunc);
|
rewriter.getUnknownLoc(), "free",
|
||||||
|
LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy,
|
||||||
|
/*isVarArg=*/false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emit MLIR for buffer_dealloc.
|
// Emit MLIR for buffer_dealloc.
|
||||||
|
|
|
@ -177,7 +177,7 @@ compileAndExecute(ModuleOp module, StringRef entryPoint,
|
||||||
static Error compileAndExecuteVoidFunction(
|
static Error compileAndExecuteVoidFunction(
|
||||||
ModuleOp module, StringRef entryPoint,
|
ModuleOp module, StringRef entryPoint,
|
||||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||||
if (!mainFunction || mainFunction.getBlocks().empty())
|
if (!mainFunction || mainFunction.getBlocks().empty())
|
||||||
return make_string_error("entry point not found");
|
return make_string_error("entry point not found");
|
||||||
void *empty = nullptr;
|
void *empty = nullptr;
|
||||||
|
@ -187,22 +187,14 @@ static Error compileAndExecuteVoidFunction(
|
||||||
static Error compileAndExecuteSingleFloatReturnFunction(
|
static Error compileAndExecuteSingleFloatReturnFunction(
|
||||||
ModuleOp module, StringRef entryPoint,
|
ModuleOp module, StringRef entryPoint,
|
||||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||||
if (!mainFunction || mainFunction.isExternal()) {
|
if (!mainFunction || mainFunction.isExternal())
|
||||||
return make_string_error("entry point not found");
|
return make_string_error("entry point not found");
|
||||||
}
|
|
||||||
|
|
||||||
if (!mainFunction.getType().getInputs().empty())
|
if (mainFunction.getType().getFunctionNumParams() != 0)
|
||||||
return make_string_error("function inputs not supported");
|
return make_string_error("function inputs not supported");
|
||||||
|
|
||||||
if (mainFunction.getType().getResults().size() != 1)
|
if (!mainFunction.getType().getFunctionResultType().isFloatTy())
|
||||||
return make_string_error("only single f32 function result supported");
|
|
||||||
|
|
||||||
auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
|
|
||||||
if (!t)
|
|
||||||
return make_string_error("only single llvm.f32 function result supported");
|
|
||||||
auto *llvmTy = t.getUnderlyingType();
|
|
||||||
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
|
|
||||||
return make_string_error("only single llvm.f32 function result supported");
|
return make_string_error("only single llvm.f32 function result supported");
|
||||||
|
|
||||||
float res;
|
float res;
|
||||||
|
|
|
@ -25,6 +25,7 @@ add_llvm_library(MLIRTargetNVVMIR
|
||||||
target_link_libraries(MLIRTargetNVVMIR
|
target_link_libraries(MLIRTargetNVVMIR
|
||||||
MLIRGPU
|
MLIRGPU
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRLLVMIR
|
||||||
MLIRNVVMIR
|
MLIRNVVMIR
|
||||||
MLIRTargetLLVMIRModuleTranslation
|
MLIRTargetLLVMIRModuleTranslation
|
||||||
)
|
)
|
||||||
|
@ -39,6 +40,7 @@ add_llvm_library(MLIRTargetROCDLIR
|
||||||
target_link_libraries(MLIRTargetROCDLIR
|
target_link_libraries(MLIRTargetROCDLIR
|
||||||
MLIRGPU
|
MLIRGPU
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRLLVMIR
|
||||||
MLIRROCDLIR
|
MLIRROCDLIR
|
||||||
MLIRTargetLLVMIRModuleTranslation
|
MLIRTargetLLVMIRModuleTranslation
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
#include "mlir/Target/NVVMIR.h"
|
#include "mlir/Target/NVVMIR.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||||
#include "mlir/IR/Function.h"
|
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||||
#include "mlir/Translation.h"
|
#include "mlir/Translation.h"
|
||||||
|
@ -66,11 +66,13 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
|
||||||
ModuleTranslation translation(m);
|
ModuleTranslation translation(m);
|
||||||
auto llvmModule =
|
auto llvmModule =
|
||||||
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
||||||
|
if (!llvmModule)
|
||||||
|
return llvmModule;
|
||||||
|
|
||||||
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
||||||
// function as a kernel.
|
// function as a kernel.
|
||||||
for (FuncOp func : m.getOps<FuncOp>()) {
|
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
if (!gpu::GPUDialect::isKernel(func))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto *llvmFunc = llvmModule->getFunction(func.getName());
|
auto *llvmFunc = llvmModule->getFunction(func.getName());
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "mlir/Target/ROCDLIR.h"
|
#include "mlir/Target/ROCDLIR.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
|
@ -93,7 +94,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
|
||||||
// foreach GPU kernel
|
// foreach GPU kernel
|
||||||
// 1. Insert AMDGPU_KERNEL calling convention.
|
// 1. Insert AMDGPU_KERNEL calling convention.
|
||||||
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
// 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
|
||||||
for (FuncOp func : m.getOps<FuncOp>()) {
|
for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
|
||||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
|
@ -39,38 +39,6 @@
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace LLVM {
|
namespace LLVM {
|
||||||
|
|
||||||
// Convert an MLIR function type to LLVM IR. Arguments of the function must of
|
|
||||||
// MLIR LLVM IR dialect types. Use `loc` as a location when reporting errors.
|
|
||||||
// Return nullptr on errors.
|
|
||||||
static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
|
|
||||||
FunctionType type, Location loc,
|
|
||||||
bool isVarArgs) {
|
|
||||||
assert(type && "expected non-null type");
|
|
||||||
if (type.getNumResults() > 1)
|
|
||||||
return emitError(loc, "LLVM functions can only have 0 or 1 result"),
|
|
||||||
nullptr;
|
|
||||||
|
|
||||||
SmallVector<llvm::Type *, 8> argTypes;
|
|
||||||
argTypes.reserve(type.getNumInputs());
|
|
||||||
for (auto t : type.getInputs()) {
|
|
||||||
auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
|
|
||||||
if (!wrappedLLVMType)
|
|
||||||
return emitError(loc, "non-LLVM function argument type"), nullptr;
|
|
||||||
argTypes.push_back(wrappedLLVMType.getUnderlyingType());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type.getNumResults() == 0)
|
|
||||||
return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
|
|
||||||
isVarArgs);
|
|
||||||
|
|
||||||
auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
|
|
||||||
if (!wrappedResultType)
|
|
||||||
return emitError(loc, "non-LLVM function result"), nullptr;
|
|
||||||
|
|
||||||
return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
|
|
||||||
argTypes, isVarArgs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
|
// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
|
||||||
// This currently supports integer, floating point, splat and dense element
|
// This currently supports integer, floating point, splat and dense element
|
||||||
// attributes and combinations thereof. In case of error, report it to `loc`
|
// attributes and combinations thereof. In case of error, report it to `loc`
|
||||||
|
@ -362,7 +330,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
|
||||||
: terminator.getSuccessorOperand(1, index);
|
: terminator.getSuccessorOperand(1, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleTranslation::connectPHINodes(FuncOp func) {
|
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
|
||||||
// Skip the first block, it cannot be branched to and its arguments correspond
|
// Skip the first block, it cannot be branched to and its arguments correspond
|
||||||
// to the arguments of the LLVM function.
|
// to the arguments of the LLVM function.
|
||||||
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
|
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
|
||||||
|
@ -393,7 +361,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort function blocks topologically.
|
// Sort function blocks topologically.
|
||||||
static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
|
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
|
||||||
// For each blocks that has not been visited yet (i.e. that has no
|
// For each blocks that has not been visited yet (i.e. that has no
|
||||||
// predecessors), add it to the list and traverse its successors in DFS
|
// predecessors), add it to the list and traverse its successors in DFS
|
||||||
// preorder.
|
// preorder.
|
||||||
|
@ -407,7 +375,7 @@ static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
|
||||||
return blocks;
|
return blocks;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
|
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
|
||||||
// Clear the block and value mappings, they are only relevant within one
|
// Clear the block and value mappings, they are only relevant within one
|
||||||
// function.
|
// function.
|
||||||
blockMapping.clear();
|
blockMapping.clear();
|
||||||
|
@ -460,24 +428,17 @@ LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
|
||||||
LogicalResult ModuleTranslation::convertFunctions() {
|
LogicalResult ModuleTranslation::convertFunctions() {
|
||||||
// Declare all functions first because there may be function calls that form a
|
// Declare all functions first because there may be function calls that form a
|
||||||
// call graph with cycles.
|
// call graph with cycles.
|
||||||
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
|
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||||
mlir::BoolAttr isVarArgsAttr =
|
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
|
||||||
function.getAttrOfType<BoolAttr>("std.varargs");
|
function.getName(),
|
||||||
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
|
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
|
||||||
llvm::FunctionType *functionType =
|
|
||||||
convertFunctionType(llvmModule->getContext(), function.getType(),
|
|
||||||
function.getLoc(), isVarArgs);
|
|
||||||
if (!functionType)
|
|
||||||
return failure();
|
|
||||||
llvm::FunctionCallee llvmFuncCst =
|
|
||||||
llvmModule->getOrInsertFunction(function.getName(), functionType);
|
|
||||||
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
|
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
|
||||||
functionMapping[function.getName()] =
|
functionMapping[function.getName()] =
|
||||||
cast<llvm::Function>(llvmFuncCst.getCallee());
|
cast<llvm::Function>(llvmFuncCst.getCallee());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert functions.
|
// Convert functions.
|
||||||
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
|
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
|
||||||
// Ignore external functions.
|
// Ignore external functions.
|
||||||
if (function.isExternal())
|
if (function.isExternal())
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -10,7 +10,7 @@ module attributes {gpu.container_module} {
|
||||||
attributes { gpu.kernel }
|
attributes { gpu.kernel }
|
||||||
}
|
}
|
||||||
|
|
||||||
func @foo() {
|
llvm.func @foo() {
|
||||||
%0 = "op"() : () -> (!llvm.float)
|
%0 = "op"() : () -> (!llvm.float)
|
||||||
%1 = "op"() : () -> (!llvm<"float*">)
|
%1 = "op"() : () -> (!llvm<"float*">)
|
||||||
%cst = constant 8 : index
|
%cst = constant 8 : index
|
||||||
|
@ -29,7 +29,7 @@ module attributes {gpu.container_module} {
|
||||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernel_module }
|
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernel_module }
|
||||||
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
|
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
|
||||||
|
|
||||||
return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s
|
// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// CHECK: attributes {gpu.kernel_module, nvvm.cubin = "CUBIN"}
|
// CHECK: attributes {gpu.kernel_module, nvvm.cubin = "CUBIN"}
|
||||||
module @kernels attributes {gpu.kernel_module} {
|
module @foo attributes {gpu.kernel_module} {
|
||||||
func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
|
llvm.func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
|
||||||
|
// CHECK: attributes {gpu.kernel}
|
||||||
attributes { gpu.kernel } {
|
attributes { gpu.kernel } {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
@ -10,15 +11,15 @@ module @kernels attributes {gpu.kernel_module} {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
module attributes {gpu.kernel_module} {
|
module @bar attributes {gpu.kernel_module} {
|
||||||
// CHECK: func @kernel_a
|
// CHECK: func @kernel_a
|
||||||
func @kernel_a()
|
llvm.func @kernel_a()
|
||||||
attributes { gpu.kernel } {
|
attributes { gpu.kernel } {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: func @kernel_b
|
// CHECK: func @kernel_b
|
||||||
func @kernel_b()
|
llvm.func @kernel_b()
|
||||||
attributes { gpu.kernel } {
|
attributes { gpu.kernel } {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @fmuladd_test
|
// CHECK-LABEL: @fmuladd_test
|
||||||
func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
|
llvm.func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">) {
|
||||||
// CHECK: call float @llvm.fmuladd.f32.f32.f32
|
// CHECK: call float @llvm.fmuladd.f32.f32.f32
|
||||||
"llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
|
"llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
|
||||||
// CHECK: call <8 x float> @llvm.fmuladd.v8f32.v8f32.v8f32
|
// CHECK: call <8 x float> @llvm.fmuladd.v8f32.v8f32.v8f32
|
||||||
|
@ -10,7 +10,7 @@ func @fmuladd_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x fl
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @exp_test
|
// CHECK-LABEL: @exp_test
|
||||||
func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
llvm.func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
||||||
// CHECK: call float @llvm.exp.f32
|
// CHECK: call float @llvm.exp.f32
|
||||||
"llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
|
"llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
|
||||||
// CHECK: call <8 x float> @llvm.exp.v8f32
|
// CHECK: call <8 x float> @llvm.exp.v8f32
|
||||||
|
|
|
@ -29,7 +29,7 @@ llvm.mlir.global @int_global_undef() : !llvm.i64
|
||||||
//
|
//
|
||||||
|
|
||||||
// CHECK: declare i8* @malloc(i64)
|
// CHECK: declare i8* @malloc(i64)
|
||||||
func @malloc(!llvm.i64) -> !llvm<"i8*">
|
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||||
// CHECK: declare void @free(i8*)
|
// CHECK: declare void @free(i8*)
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,12 +41,12 @@ func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||||
// CHECK-LABEL: define void @empty() {
|
// CHECK-LABEL: define void @empty() {
|
||||||
// CHECK-NEXT: ret void
|
// CHECK-NEXT: ret void
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
func @empty() {
|
llvm.func @empty() {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @global_refs
|
// CHECK-LABEL: @global_refs
|
||||||
func @global_refs() {
|
llvm.func @global_refs() {
|
||||||
// Check load from globals.
|
// Check load from globals.
|
||||||
// CHECK: load i32, i32* @i32_global
|
// CHECK: load i32, i32* @i32_global
|
||||||
%0 = llvm.mlir.addressof @i32_global : !llvm<"i32*">
|
%0 = llvm.mlir.addressof @i32_global : !llvm<"i32*">
|
||||||
|
@ -63,11 +63,11 @@ func @global_refs() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: declare void @body(i64)
|
// CHECK-LABEL: declare void @body(i64)
|
||||||
func @body(!llvm.i64)
|
llvm.func @body(!llvm.i64)
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: define void @simple_loop() {
|
// CHECK-LABEL: define void @simple_loop() {
|
||||||
func @simple_loop() {
|
llvm.func @simple_loop() {
|
||||||
// CHECK: br label %[[SIMPLE_bb1:[0-9]+]]
|
// CHECK: br label %[[SIMPLE_bb1:[0-9]+]]
|
||||||
llvm.br ^bb1
|
llvm.br ^bb1
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ func @simple_loop() {
|
||||||
// CHECK-NEXT: call void @simple_loop()
|
// CHECK-NEXT: call void @simple_loop()
|
||||||
// CHECK-NEXT: ret void
|
// CHECK-NEXT: ret void
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
func @simple_caller() {
|
llvm.func @simple_caller() {
|
||||||
llvm.call @simple_loop() : () -> ()
|
llvm.call @simple_loop() : () -> ()
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
@ -124,20 +124,20 @@ func @simple_caller() {
|
||||||
// CHECK-NEXT: call void @more_imperfectly_nested_loops()
|
// CHECK-NEXT: call void @more_imperfectly_nested_loops()
|
||||||
// CHECK-NEXT: ret void
|
// CHECK-NEXT: ret void
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
func @ml_caller() {
|
llvm.func @ml_caller() {
|
||||||
llvm.call @simple_loop() : () -> ()
|
llvm.call @simple_loop() : () -> ()
|
||||||
llvm.call @more_imperfectly_nested_loops() : () -> ()
|
llvm.call @more_imperfectly_nested_loops() : () -> ()
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: declare i64 @body_args(i64)
|
// CHECK-LABEL: declare i64 @body_args(i64)
|
||||||
func @body_args(!llvm.i64) -> !llvm.i64
|
llvm.func @body_args(!llvm.i64) -> !llvm.i64
|
||||||
// CHECK-LABEL: declare i32 @other(i64, i32)
|
// CHECK-LABEL: declare i32 @other(i64, i32)
|
||||||
func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
|
llvm.func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
|
||||||
|
|
||||||
// CHECK-LABEL: define i32 @func_args(i32 {{%.*}}, i32 {{%.*}}) {
|
// CHECK-LABEL: define i32 @func_args(i32 {{%.*}}, i32 {{%.*}}) {
|
||||||
// CHECK-NEXT: br label %[[ARGS_bb1:[0-9]+]]
|
// CHECK-NEXT: br label %[[ARGS_bb1:[0-9]+]]
|
||||||
func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
llvm.func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
||||||
%0 = llvm.mlir.constant(0 : i32) : !llvm.i32
|
%0 = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
llvm.br ^bb1
|
llvm.br ^bb1
|
||||||
|
|
||||||
|
@ -182,17 +182,17 @@ func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: declare void @pre(i64)
|
// CHECK: declare void @pre(i64)
|
||||||
func @pre(!llvm.i64)
|
llvm.func @pre(!llvm.i64)
|
||||||
|
|
||||||
// CHECK: declare void @body2(i64, i64)
|
// CHECK: declare void @body2(i64, i64)
|
||||||
func @body2(!llvm.i64, !llvm.i64)
|
llvm.func @body2(!llvm.i64, !llvm.i64)
|
||||||
|
|
||||||
// CHECK: declare void @post(i64)
|
// CHECK: declare void @post(i64)
|
||||||
func @post(!llvm.i64)
|
llvm.func @post(!llvm.i64)
|
||||||
|
|
||||||
// CHECK-LABEL: define void @imperfectly_nested_loops() {
|
// CHECK-LABEL: define void @imperfectly_nested_loops() {
|
||||||
// CHECK-NEXT: br label %[[IMPER_bb1:[0-9]+]]
|
// CHECK-NEXT: br label %[[IMPER_bb1:[0-9]+]]
|
||||||
func @imperfectly_nested_loops() {
|
llvm.func @imperfectly_nested_loops() {
|
||||||
llvm.br ^bb1
|
llvm.br ^bb1
|
||||||
|
|
||||||
// CHECK: [[IMPER_bb1]]:
|
// CHECK: [[IMPER_bb1]]:
|
||||||
|
@ -259,10 +259,10 @@ func @imperfectly_nested_loops() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: declare void @mid(i64)
|
// CHECK: declare void @mid(i64)
|
||||||
func @mid(!llvm.i64)
|
llvm.func @mid(!llvm.i64)
|
||||||
|
|
||||||
// CHECK: declare void @body3(i64, i64)
|
// CHECK: declare void @body3(i64, i64)
|
||||||
func @body3(!llvm.i64, !llvm.i64)
|
llvm.func @body3(!llvm.i64, !llvm.i64)
|
||||||
|
|
||||||
// A complete function transformation check.
|
// A complete function transformation check.
|
||||||
// CHECK-LABEL: define void @more_imperfectly_nested_loops() {
|
// CHECK-LABEL: define void @more_imperfectly_nested_loops() {
|
||||||
|
@ -306,7 +306,7 @@ func @body3(!llvm.i64, !llvm.i64)
|
||||||
// CHECK: 21: ; preds = %2
|
// CHECK: 21: ; preds = %2
|
||||||
// CHECK-NEXT: ret void
|
// CHECK-NEXT: ret void
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
func @more_imperfectly_nested_loops() {
|
llvm.func @more_imperfectly_nested_loops() {
|
||||||
llvm.br ^bb1
|
llvm.br ^bb1
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%0 = llvm.mlir.constant(0 : index) : !llvm.i64
|
%0 = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||||
|
@ -359,7 +359,7 @@ func @more_imperfectly_nested_loops() {
|
||||||
//
|
//
|
||||||
|
|
||||||
// CHECK-LABEL: define void @memref_alloc()
|
// CHECK-LABEL: define void @memref_alloc()
|
||||||
func @memref_alloc() {
|
llvm.func @memref_alloc() {
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 400)
|
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 400)
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = insertvalue { float* } undef, float* %{{[0-9]+}}, 0
|
// CHECK-NEXT: %{{[0-9]+}} = insertvalue { float* } undef, float* %{{[0-9]+}}, 0
|
||||||
|
@ -377,10 +377,10 @@ func @memref_alloc() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: declare i64 @get_index()
|
// CHECK-LABEL: declare i64 @get_index()
|
||||||
func @get_index() -> !llvm.i64
|
llvm.func @get_index() -> !llvm.i64
|
||||||
|
|
||||||
// CHECK-LABEL: define void @store_load_static()
|
// CHECK-LABEL: define void @store_load_static()
|
||||||
func @store_load_static() {
|
llvm.func @store_load_static() {
|
||||||
^bb0:
|
^bb0:
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 40)
|
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 40)
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||||
|
@ -448,7 +448,7 @@ func @store_load_static() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define void @store_load_dynamic(i64 {{%.*}})
|
// CHECK-LABEL: define void @store_load_dynamic(i64 {{%.*}})
|
||||||
func @store_load_dynamic(%arg0: !llvm.i64) {
|
llvm.func @store_load_dynamic(%arg0: !llvm.i64) {
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 %{{[0-9]+}})
|
// CHECK-NEXT: %{{[0-9]+}} = call i8* @malloc(i64 %{{[0-9]+}})
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
// CHECK-NEXT: %{{[0-9]+}} = bitcast i8* %{{[0-9]+}} to float*
|
||||||
|
@ -518,7 +518,7 @@ func @store_load_dynamic(%arg0: !llvm.i64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define void @store_load_mixed(i64 {{%.*}})
|
// CHECK-LABEL: define void @store_load_mixed(i64 {{%.*}})
|
||||||
func @store_load_mixed(%arg0: !llvm.i64) {
|
llvm.func @store_load_mixed(%arg0: !llvm.i64) {
|
||||||
%0 = llvm.mlir.constant(10 : index) : !llvm.i64
|
%0 = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 2, %{{[0-9]+}}
|
// CHECK-NEXT: %{{[0-9]+}} = mul i64 2, %{{[0-9]+}}
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
// CHECK-NEXT: %{{[0-9]+}} = mul i64 %{{[0-9]+}}, 4
|
||||||
|
@ -603,7 +603,7 @@ func @store_load_mixed(%arg0: !llvm.i64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define { float*, i64 } @memref_args_rets({ float* } {{%.*}}, { float*, i64 } {{%.*}}, { float*, i64 } {{%.*}}) {
|
// CHECK-LABEL: define { float*, i64 } @memref_args_rets({ float* } {{%.*}}, { float*, i64 } {{%.*}}, { float*, i64 } {{%.*}}) {
|
||||||
func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
|
llvm.func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }">, %arg2: !llvm<"{ float*, i64 }">) -> !llvm<"{ float*, i64 }"> {
|
||||||
%0 = llvm.mlir.constant(7 : index) : !llvm.i64
|
%0 = llvm.mlir.constant(7 : index) : !llvm.i64
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = call i64 @get_index()
|
// CHECK-NEXT: %{{[0-9]+}} = call i64 @get_index()
|
||||||
%1 = llvm.call @get_index() : () -> !llvm.i64
|
%1 = llvm.call @get_index() : () -> !llvm.i64
|
||||||
|
@ -657,7 +657,7 @@ func @memref_args_rets(%arg0: !llvm<"{ float* }">, %arg1: !llvm<"{ float*, i64 }
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: define i64 @memref_dim({ float*, i64, i64 } {{%.*}})
|
// CHECK-LABEL: define i64 @memref_dim({ float*, i64, i64 } {{%.*}})
|
||||||
func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
llvm.func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
||||||
// Expecting this to create an LLVM constant.
|
// Expecting this to create an LLVM constant.
|
||||||
%0 = llvm.mlir.constant(42 : index) : !llvm.i64
|
%0 = llvm.mlir.constant(42 : index) : !llvm.i64
|
||||||
// CHECK-NEXT: %2 = extractvalue { float*, i64, i64 } %0, 1
|
// CHECK-NEXT: %2 = extractvalue { float*, i64, i64 } %0, 1
|
||||||
|
@ -678,12 +678,12 @@ func @memref_dim(%arg0: !llvm<"{ float*, i64, i64 }">) -> !llvm.i64 {
|
||||||
llvm.return %6 : !llvm.i64
|
llvm.return %6 : !llvm.i64
|
||||||
}
|
}
|
||||||
|
|
||||||
func @get_i64() -> !llvm.i64
|
llvm.func @get_i64() -> !llvm.i64
|
||||||
func @get_f32() -> !llvm.float
|
llvm.func @get_f32() -> !llvm.float
|
||||||
func @get_memref() -> !llvm<"{ float*, i64, i64 }">
|
llvm.func @get_memref() -> !llvm<"{ float*, i64, i64 }">
|
||||||
|
|
||||||
// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
|
// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
|
||||||
func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
llvm.func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
||||||
%0 = llvm.call @get_i64() : () -> !llvm.i64
|
%0 = llvm.call @get_i64() : () -> !llvm.i64
|
||||||
%1 = llvm.call @get_f32() : () -> !llvm.float
|
%1 = llvm.call @get_f32() : () -> !llvm.float
|
||||||
%2 = llvm.call @get_memref() : () -> !llvm<"{ float*, i64, i64 }">
|
%2 = llvm.call @get_memref() : () -> !llvm<"{ float*, i64, i64 }">
|
||||||
|
@ -700,7 +700,7 @@ func @multireturn() -> !llvm<"{ i64, float, { float*, i64, i64 } }"> {
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: define void @multireturn_caller() {
|
// CHECK-LABEL: define void @multireturn_caller() {
|
||||||
func @multireturn_caller() {
|
llvm.func @multireturn_caller() {
|
||||||
// CHECK-NEXT: %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
|
// CHECK-NEXT: %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
|
||||||
// CHECK-NEXT: [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
|
// CHECK-NEXT: [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
|
||||||
// CHECK-NEXT: [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
|
// CHECK-NEXT: [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
|
||||||
|
@ -734,7 +734,7 @@ func @multireturn_caller() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define <4 x float> @vector_ops(<4 x float> {{%.*}}, <4 x i1> {{%.*}}, <4 x i64> {{%.*}}) {
|
// CHECK-LABEL: define <4 x float> @vector_ops(<4 x float> {{%.*}}, <4 x i1> {{%.*}}, <4 x i64> {{%.*}}) {
|
||||||
func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
|
llvm.func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
|
||||||
%0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
|
%0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
|
||||||
// CHECK-NEXT: %4 = fadd <4 x float> %0, <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
|
// CHECK-NEXT: %4 = fadd <4 x float> %0, <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
|
||||||
%1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
|
%1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
|
||||||
|
@ -763,7 +763,7 @@ func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @ops
|
// CHECK-LABEL: @ops
|
||||||
func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
|
llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
|
||||||
// CHECK-NEXT: fsub float %0, %1
|
// CHECK-NEXT: fsub float %0, %1
|
||||||
%0 = llvm.fsub %arg0, %arg1 : !llvm.float
|
%0 = llvm.fsub %arg0, %arg1 : !llvm.float
|
||||||
// CHECK-NEXT: %6 = sub i32 %2, %3
|
// CHECK-NEXT: %6 = sub i32 %2, %3
|
||||||
|
@ -811,7 +811,7 @@ func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm
|
||||||
//
|
//
|
||||||
|
|
||||||
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) {
|
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}}) {
|
||||||
func @indirect_const_call(%arg0: !llvm.i64) {
|
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
|
||||||
// CHECK-NEXT: call void @body(i64 %0)
|
// CHECK-NEXT: call void @body(i64 %0)
|
||||||
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
|
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
|
||||||
llvm.call %0(%arg0) : (!llvm.i64) -> ()
|
llvm.call %0(%arg0) : (!llvm.i64) -> ()
|
||||||
|
@ -820,7 +820,7 @@ func @indirect_const_call(%arg0: !llvm.i64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define i32 @indirect_call(i32 (float)* {{%.*}}, float {{%.*}}) {
|
// CHECK-LABEL: define i32 @indirect_call(i32 (float)* {{%.*}}, float {{%.*}}) {
|
||||||
func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
|
llvm.func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i32 {
|
||||||
// CHECK-NEXT: %3 = call i32 %0(float %1)
|
// CHECK-NEXT: %3 = call i32 %0(float %1)
|
||||||
%0 = llvm.call %arg0(%arg1) : (!llvm.float) -> !llvm.i32
|
%0 = llvm.call %arg0(%arg1) : (!llvm.float) -> !llvm.i32
|
||||||
// CHECK-NEXT: ret i32 %3
|
// CHECK-NEXT: ret i32 %3
|
||||||
|
@ -833,7 +833,7 @@ func @indirect_call(%arg0: !llvm<"i32 (float)*">, %arg1: !llvm.float) -> !llvm.i
|
||||||
//
|
//
|
||||||
|
|
||||||
// CHECK-LABEL: define void @cond_br_arguments(i1 {{%.*}}, i1 {{%.*}}) {
|
// CHECK-LABEL: define void @cond_br_arguments(i1 {{%.*}}, i1 {{%.*}}) {
|
||||||
func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
llvm.func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
||||||
// CHECK-NEXT: br i1 %0, label %3, label %5
|
// CHECK-NEXT: br i1 %0, label %3, label %5
|
||||||
llvm.cond_br %arg0, ^bb1(%arg0 : !llvm.i1), ^bb2
|
llvm.cond_br %arg0, ^bb1(%arg0 : !llvm.i1), ^bb2
|
||||||
|
|
||||||
|
@ -850,15 +850,14 @@ func @cond_br_arguments(%arg0: !llvm.i1, %arg1: !llvm.i1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define void @llvm_noalias(float* noalias {{%*.}}) {
|
// CHECK-LABEL: define void @llvm_noalias(float* noalias {{%*.}}) {
|
||||||
func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
|
llvm.func @llvm_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @llvm_varargs(...)
|
// CHECK-LABEL: @llvm_varargs(...)
|
||||||
func @llvm_varargs()
|
llvm.func @llvm_varargs(...)
|
||||||
attributes {std.varargs = true}
|
|
||||||
|
|
||||||
func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
llvm.func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
||||||
// CHECK: %2 = inttoptr i32 %0 to i32*
|
// CHECK: %2 = inttoptr i32 %0 to i32*
|
||||||
// CHECK-NEXT: %3 = ptrtoint i32* %2 to i32
|
// CHECK-NEXT: %3 = ptrtoint i32* %2 to i32
|
||||||
%1 = llvm.inttoptr %arg0 : !llvm.i32 to !llvm<"i32*">
|
%1 = llvm.inttoptr %arg0 : !llvm.i32 to !llvm<"i32*">
|
||||||
|
@ -866,19 +865,19 @@ func @intpointerconversion(%arg0 : !llvm.i32) -> !llvm.i32 {
|
||||||
llvm.return %2 : !llvm.i32
|
llvm.return %2 : !llvm.i32
|
||||||
}
|
}
|
||||||
|
|
||||||
func @stringconstant() -> !llvm<"i8*"> {
|
llvm.func @stringconstant() -> !llvm<"i8*"> {
|
||||||
%1 = llvm.mlir.constant("Hello world!") : !llvm<"i8*">
|
%1 = llvm.mlir.constant("Hello world!") : !llvm<"i8*">
|
||||||
// CHECK: ret [12 x i8] c"Hello world!"
|
// CHECK: ret [12 x i8] c"Hello world!"
|
||||||
llvm.return %1 : !llvm<"i8*">
|
llvm.return %1 : !llvm<"i8*">
|
||||||
}
|
}
|
||||||
|
|
||||||
func @noreach() {
|
llvm.func @noreach() {
|
||||||
// CHECK: unreachable
|
// CHECK: unreachable
|
||||||
llvm.unreachable
|
llvm.unreachable
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: define void @fcmp
|
// CHECK-LABEL: define void @fcmp
|
||||||
func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
llvm.func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
||||||
// CHECK: fcmp oeq float %0, %1
|
// CHECK: fcmp oeq float %0, %1
|
||||||
// CHECK-NEXT: fcmp ogt float %0, %1
|
// CHECK-NEXT: fcmp ogt float %0, %1
|
||||||
// CHECK-NEXT: fcmp oge float %0, %1
|
// CHECK-NEXT: fcmp oge float %0, %1
|
||||||
|
@ -911,7 +910,7 @@ func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @vect
|
// CHECK-LABEL: @vect
|
||||||
func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
llvm.func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
||||||
// CHECK-NEXT: extractelement <4 x float> {{.*}}, i32 {{.*}}
|
// CHECK-NEXT: extractelement <4 x float> {{.*}}, i32 {{.*}}
|
||||||
// CHECK-NEXT: insertelement <4 x float> {{.*}}, float %2, i32 {{.*}}
|
// CHECK-NEXT: insertelement <4 x float> {{.*}}, float %2, i32 {{.*}}
|
||||||
// CHECK-NEXT: shufflevector <4 x float> {{.*}}, <4 x float> {{.*}}, <5 x i32> <i32 0, i32 0, i32 0, i32 0, i32 7>
|
// CHECK-NEXT: shufflevector <4 x float> {{.*}}, <4 x float> {{.*}}, <5 x i32> <i32 0, i32 0, i32 0, i32 0, i32 7>
|
||||||
|
@ -922,7 +921,7 @@ func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @alloca
|
// CHECK-LABEL: @alloca
|
||||||
func @alloca(%size : !llvm.i64) {
|
llvm.func @alloca(%size : !llvm.i64) {
|
||||||
// CHECK: alloca
|
// CHECK: alloca
|
||||||
// CHECK-NOT: align
|
// CHECK-NOT: align
|
||||||
llvm.alloca %size x !llvm.i32 {alignment = 0} : (!llvm.i64) -> (!llvm<"i32*">)
|
llvm.alloca %size x !llvm.i32 {alignment = 0} : (!llvm.i64) -> (!llvm<"i32*">)
|
||||||
|
@ -932,13 +931,14 @@ func @alloca(%size : !llvm.i64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @constants
|
// CHECK-LABEL: @constants
|
||||||
func @constants() -> !llvm<"<4 x float>"> {
|
llvm.func @constants() -> !llvm<"<4 x float>"> {
|
||||||
// CHECK: ret <4 x float> <float 4.2{{0*}}e+01, float 0.{{0*}}e+00, float 0.{{0*}}e+00, float 0.{{0*}}e+00>
|
// CHECK: ret <4 x float> <float 4.2{{0*}}e+01, float 0.{{0*}}e+00, float 0.{{0*}}e+00, float 0.{{0*}}e+00>
|
||||||
%0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : !llvm<"<4 x float>">
|
%0 = llvm.mlir.constant(sparse<[[0]], [4.2e+01]> : vector<4xf32>) : !llvm<"<4 x float>">
|
||||||
llvm.return %0 : !llvm<"<4 x float>">
|
llvm.return %0 : !llvm<"<4 x float>">
|
||||||
}
|
}
|
||||||
|
|
||||||
func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
// CHECK-LABEL: @fp_casts
|
||||||
|
llvm.func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
||||||
// CHECK: fptrunc double {{.*}} to float
|
// CHECK: fptrunc double {{.*}} to float
|
||||||
%a = llvm.fptrunc %fp2 : !llvm<"double"> to !llvm<"float">
|
%a = llvm.fptrunc %fp2 : !llvm<"double"> to !llvm<"float">
|
||||||
// CHECK: fpext float {{.*}} to double
|
// CHECK: fpext float {{.*}} to double
|
||||||
|
@ -948,7 +948,8 @@ func @fp_casts(%fp1 : !llvm<"float">, %fp2 : !llvm<"double">) -> !llvm.i16 {
|
||||||
llvm.return %c : !llvm.i16
|
llvm.return %c : !llvm.i16
|
||||||
}
|
}
|
||||||
|
|
||||||
func @integer_extension_and_truncation(%a : !llvm.i32) {
|
// CHECK-LABEL: @integer_extension_and_truncation
|
||||||
|
llvm.func @integer_extension_and_truncation(%a : !llvm.i32) {
|
||||||
// CHECK: sext i32 {{.*}} to i64
|
// CHECK: sext i32 {{.*}} to i64
|
||||||
// CHECK: zext i32 {{.*}} to i64
|
// CHECK: zext i32 {{.*}} to i64
|
||||||
// CHECK: trunc i32 {{.*}} to i16
|
// CHECK: trunc i32 {{.*}} to i16
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
|
// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
|
||||||
|
|
||||||
func @nvvm_special_regs() -> !llvm.i32 {
|
llvm.func @nvvm_special_regs() -> !llvm.i32 {
|
||||||
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
// CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||||
%1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
%1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
|
||||||
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
|
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
|
||||||
%2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
%2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
|
||||||
|
@ -32,13 +32,13 @@ func @nvvm_special_regs() -> !llvm.i32 {
|
||||||
llvm.return %1 : !llvm.i32
|
llvm.return %1 : !llvm.i32
|
||||||
}
|
}
|
||||||
|
|
||||||
func @llvm.nvvm.barrier0() {
|
llvm.func @llvm.nvvm.barrier0() {
|
||||||
// CHECK: call void @llvm.nvvm.barrier0()
|
// CHECK: call void @llvm.nvvm.barrier0()
|
||||||
nvvm.barrier0
|
nvvm.barrier0
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @nvvm_shfl(
|
llvm.func @nvvm_shfl(
|
||||||
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
|
||||||
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
|
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
|
||||||
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
|
||||||
|
@ -48,7 +48,7 @@ func @nvvm_shfl(
|
||||||
llvm.return %6 : !llvm.i32
|
llvm.return %6 : !llvm.i32
|
||||||
}
|
}
|
||||||
|
|
||||||
func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
llvm.func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
||||||
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
|
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
|
||||||
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
|
%3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
|
||||||
llvm.return %3 : !llvm.i32
|
llvm.return %3 : !llvm.i32
|
||||||
|
@ -56,7 +56,7 @@ func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
|
||||||
|
|
||||||
// This function has the "kernel" attribute attached and should appear in the
|
// This function has the "kernel" attribute attached and should appear in the
|
||||||
// NVVM annotations after conversion.
|
// NVVM annotations after conversion.
|
||||||
func @kernel_func() attributes {gpu.kernel} {
|
llvm.func @kernel_func() attributes {gpu.kernel} {
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
// RUN: mlir-translate -mlir-to-rocdlir %s | FileCheck %s
|
// RUN: mlir-translate -mlir-to-rocdlir %s | FileCheck %s
|
||||||
|
|
||||||
func @rocdl_special_regs() -> !llvm.i32 {
|
llvm.func @rocdl_special_regs() -> !llvm.i32 {
|
||||||
// CHECK-LABEL: rocdl_special_regs
|
// CHECK-LABEL: rocdl_special_regs
|
||||||
// CHECK: call i32 @llvm.amdgcn.workitem.id.x()
|
// CHECK: call i32 @llvm.amdgcn.workitem.id.x()
|
||||||
%1 = rocdl.workitem.id.x : !llvm.i32
|
%1 = rocdl.workitem.id.x : !llvm.i32
|
||||||
|
@ -29,7 +29,7 @@ func @rocdl_special_regs() -> !llvm.i32 {
|
||||||
llvm.return %1 : !llvm.i32
|
llvm.return %1 : !llvm.i32
|
||||||
}
|
}
|
||||||
|
|
||||||
func @kernel_func() attributes {gpu.kernel} {
|
llvm.func @kernel_func() attributes {gpu.kernel} {
|
||||||
// CHECK-LABEL: amdgpu_kernel void @kernel_func
|
// CHECK-LABEL: amdgpu_kernel void @kernel_func
|
||||||
llvm.return
|
llvm.return
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,12 +12,12 @@
|
||||||
// RUN: rm %T/test.o
|
// RUN: rm %T/test.o
|
||||||
|
|
||||||
// Declarations of C library functions.
|
// Declarations of C library functions.
|
||||||
func @fabsf(!llvm.float) -> !llvm.float
|
llvm.func @fabsf(!llvm.float) -> !llvm.float
|
||||||
func @malloc(!llvm.i64) -> !llvm<"i8*">
|
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
|
||||||
func @free(!llvm<"i8*">)
|
llvm.func @free(!llvm<"i8*">)
|
||||||
|
|
||||||
// Check that a simple function with a nested call works.
|
// Check that a simple function with a nested call works.
|
||||||
func @main() -> !llvm.float {
|
llvm.func @main() -> !llvm.float {
|
||||||
%0 = llvm.mlir.constant(-4.200000e+02 : f32) : !llvm.float
|
%0 = llvm.mlir.constant(-4.200000e+02 : f32) : !llvm.float
|
||||||
%1 = llvm.call @fabsf(%0) : (!llvm.float) -> !llvm.float
|
%1 = llvm.call @fabsf(%0) : (!llvm.float) -> !llvm.float
|
||||||
llvm.return %1 : !llvm.float
|
llvm.return %1 : !llvm.float
|
||||||
|
@ -25,13 +25,13 @@ func @main() -> !llvm.float {
|
||||||
// CHECK: 4.200000e+02
|
// CHECK: 4.200000e+02
|
||||||
|
|
||||||
// Helper typed functions wrapping calls to "malloc" and "free".
|
// Helper typed functions wrapping calls to "malloc" and "free".
|
||||||
func @allocation() -> !llvm<"float*"> {
|
llvm.func @allocation() -> !llvm<"float*"> {
|
||||||
%0 = llvm.mlir.constant(4 : index) : !llvm.i64
|
%0 = llvm.mlir.constant(4 : index) : !llvm.i64
|
||||||
%1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*">
|
%1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*">
|
||||||
%2 = llvm.bitcast %1 : !llvm<"i8*"> to !llvm<"float*">
|
%2 = llvm.bitcast %1 : !llvm<"i8*"> to !llvm<"float*">
|
||||||
llvm.return %2 : !llvm<"float*">
|
llvm.return %2 : !llvm<"float*">
|
||||||
}
|
}
|
||||||
func @deallocation(%arg0: !llvm<"float*">) {
|
llvm.func @deallocation(%arg0: !llvm<"float*">) {
|
||||||
%0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*">
|
%0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*">
|
||||||
llvm.call @free(%0) : (!llvm<"i8*">) -> ()
|
llvm.call @free(%0) : (!llvm<"i8*">) -> ()
|
||||||
llvm.return
|
llvm.return
|
||||||
|
@ -39,7 +39,7 @@ func @deallocation(%arg0: !llvm<"float*">) {
|
||||||
|
|
||||||
// Check that allocation and deallocation works, and that a custom entry point
|
// Check that allocation and deallocation works, and that a custom entry point
|
||||||
// works.
|
// works.
|
||||||
func @foo() -> !llvm.float {
|
llvm.func @foo() -> !llvm.float {
|
||||||
%0 = llvm.call @allocation() : () -> !llvm<"float*">
|
%0 = llvm.call @allocation() : () -> !llvm<"float*">
|
||||||
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
|
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||||
%2 = llvm.mlir.constant(1.234000e+03 : f32) : !llvm.float
|
%2 = llvm.mlir.constant(1.234000e+03 : f32) : !llvm.float
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
|
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_0d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-0D
|
||||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
|
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_1d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-1D
|
||||||
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
|
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm -lower-to-llvm | mlir-cpu-runner -e print_3d -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext | FileCheck %s --check-prefix=PRINT-3D
|
||||||
|
|
||||||
func @print_0d() {
|
func @print_0d() {
|
||||||
%f = constant 2.00000e+00 : f32
|
%f = constant 2.00000e+00 : f32
|
||||||
|
|
Loading…
Reference in New Issue