[mlir] use unpacked memref descriptors at function boundaries

The existing (default) calling convention for memrefs in standard-to-LLVM
conversion was motivated by interfacing with LLVM IR produced from C sources.
In particular, it passes a pointer to the memref descriptor structure when
calling the function. Therefore, the descriptor is allocated on stack before
the call. This convention leads to several problems. PR44644 indicates a
problem with stack exhaustion when calling functions with memref-typed
arguments in a loop. Allocating outside of the loop may lead to concurrent
access problems in case the loop is parallel. When targeting GPUs, the contents
of the stack-allocated memory for the descriptor (passed by pointer) need to
be explicitly copied to the device. Using an aggregate type makes it impossible
to attach pointer-specific argument attributes pertaining to alignment and
aliasing in the LLVM dialect.

Change the default calling convention for memrefs in standard-to-LLVM
conversion so that a memref is expanded into a list of arguments, each of
primitive type, that together constitute the memref descriptor. This avoids
stack allocation
for ranked memrefs (and thus stack exhaustion and potential concurrent access
problems) and simplifies the device function invocation on GPUs.

Provide an option in the standard-to-LLVM conversion to generate auxiliary
wrapper functions with the same interface as the previous calling convention,
compatible with LLVM IR produced from C sources. These auxiliary functions
pack the individual values into a descriptor structure or unpack it. They also
handle descriptor stack allocation if necessary, serving as an allocation
scope: the memory reserved by `alloca` will be freed on exiting the auxiliary
function.

The effect of this change on LLVM IR generated solely from MLIR is minimal.
When interfacing MLIR-generated LLVM IR with C-generated LLVM IR, the
integration only needs to enable the auxiliary functions and change the callee
name so that the wrapper function is called instead of the original function.

This also opens the door to forwarding aliasing and alignment information from
memrefs to LLVM IR pointers in the standard-to-LLVM conversion.
Alex Zinenko, 2020-02-10 14:12:47 +01:00
commit 5a1778057f (parent 1dc62d0358)
25 changed files with 1148 additions and 389 deletions


@ -248,58 +248,123 @@ func @bar() {
### Calling Convention for `memref`
For function _arguments_ of `memref` type, ranked or unranked, the type of the
argument is a _pointer_ to the memref descriptor type defined above. The caller
of such function is required to store the descriptor in memory and guarantee
that the storage remains live until the callee returns. The caller can then pass
the pointer to that memory as function argument. The callee loads from the
pointers it was passed as arguments in the entry block of the function, making
the descriptor passed in as argument available for use similarly to
locally-defined descriptors.
Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a
list of arguments of non-aggregate types that the memref descriptor defined
above comprises. That is, the outer struct type and the inner array types are
replaced with individual arguments.
This convention is implemented in the conversion of `std.func` and `std.call` to
the LLVM dialect. Conversions from other dialects should take it into account.
The motivation for this convention is to simplify the ABI for interfacing with
other LLVM modules, in particular those generated from C sources, while avoiding
platform-specific aspects until MLIR has a proper ABI modeling.
the LLVM dialect, with the former unpacking the descriptor into a set of
individual values and the latter packing those values back into a descriptor so
as to make it transparently usable by other operations. Conversions from other
dialects should take this convention into account.
Example:
This specific convention is motivated by the necessity to specify alignment and
aliasing attributes on the raw pointers underpinning the memref.
Examples:
```mlir
func @foo(memref<?xf32>) -> () {
%c0 = constant 0 : index
load %arg0[%c0] : memref<?xf32>
func @foo(%arg0: memref<?xf32>) -> () {
"use"(%arg0) : (memref<?xf32>) -> ()
return
}
func @bar(%arg0: index) {
%0 = alloc(%arg0) : memref<?xf32>
call @foo(%0) : (memref<?xf32>) -> ()
// Gets converted to the following.
llvm.func @foo(%arg0: !llvm<"float*">, // Allocated pointer.
%arg1: !llvm<"float*">, // Aligned pointer.
%arg2: !llvm.i64, // Offset.
%arg3: !llvm.i64, // Size in dim 0.
%arg4: !llvm.i64) { // Stride in dim 0.
// Populate memref descriptor structure.
%0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// Descriptor is now usable as a single value.
"use"(%5) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">) -> ()
llvm.return
}
```
```mlir
func @bar() {
%0 = "get"() : () -> (memref<?xf32>)
call @foo(%0) : (memref<?xf32>) -> ()
return
}
// Gets converted to the following IR.
// Accepts a pointer to the memref descriptor.
llvm.func @foo(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) {
// Loads the descriptor so that it can be used similarly to locally
// created descriptors.
%0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
// Gets converted to the following.
llvm.func @bar() {
%0 = "get"() : () -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// Unpack the memref descriptor.
%1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
%5 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// Pass individual values to the callee.
llvm.call @foo(%1, %2, %3, %4, %5) : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
llvm.return
}
llvm.func @bar(%arg0: !llvm.i64) {
// ... Allocation ...
// Definition of the descriptor.
%7 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// ... Filling in the descriptor ...
%14 = // The final value of the allocated descriptor.
// Allocate the memory for the descriptor and store it.
%15 = llvm.mlir.constant(1 : index) : !llvm.i64
%16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
: (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
// Pass the pointer to the function.
llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
```
For **unranked** memrefs, the list of function arguments always contains two
elements, same as the unranked memref descriptor: an integer rank, and a
type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that
while the _calling convention_ does not require stack allocation, _casting_ to
unranked memref does since one cannot take an address of an SSA value containing
the ranked memref. The caller is in charge of ensuring the thread safety and
eventually removing unnecessary stack allocations in cast operations.
Example:
```mlir
func @foo(%arg0: memref<*xf32>) -> () {
"use"(%arg0) : (memref<*xf32>) -> ()
return
}
// Gets converted to the following.
llvm.func @foo(%arg0: !llvm.i64, // Rank.
%arg1: !llvm<"i8*">) { // Type-erased pointer to descriptor.
// Pack the unranked memref descriptor.
%0 = llvm.mlir.undef : !llvm<"{ i64, i8* }">
%1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i8* }">
%2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i8* }">
"use"(%2) : (!llvm<"{ i64, i8* }">) -> ()
llvm.return
}
```
```mlir
func @bar() {
%0 = "get"() : () -> (memref<*xf32>)
call @foo(%0) : (memref<*xf32>) -> ()
return
}
// Gets converted to the following.
llvm.func @bar() {
%0 = "get"() : () -> (!llvm<"{ i64, i8* }">)
// Unpack the memref descriptor.
%1 = llvm.extractvalue %0[0] : !llvm<"{ i64, i8* }">
%2 = llvm.extractvalue %0[1] : !llvm<"{ i64, i8* }">
// Pass individual values to the callee.
llvm.call @foo(%1, %2) : (!llvm.i64, !llvm<"i8*">) -> ()
llvm.return
}
```
@ -307,6 +372,141 @@ llvm.func @bar(%arg0: !llvm.i64) {
*This convention may or may not apply if the conversion of MemRef types is
overridden by the user.*
### C-compatible wrapper emission
In practical cases, it may be desirable to have externally facing functions
with a single argument corresponding to a MemRef argument. When interfacing
with LLVM IR produced from C, the code needs to respect the corresponding
calling convention. The conversion to the LLVM dialect provides an option to
generate wrapper functions that take memref descriptors as pointers-to-struct
compatible with data types produced by Clang when compiling C sources.
More specifically, a memref argument is converted into a pointer-to-struct
argument of type `{T*, T*, i64, i64[N], i64[N]}*` in the wrapper function, where
`T` is the converted element type and `N` is the memref rank. This type is
compatible with that produced by Clang for the following C++ structure template
instantiations or their equivalents in C.
```cpp
template<typename T, size_t N>
struct MemRefDescriptor {
T *allocated;
T *aligned;
intptr_t offset;
intptr_t sizes[N];
intptr_t strides[N];
};
```
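For illustration, a C or C++ caller can construct such a descriptor and pass
its address to the generated wrapper. The following is a hedged sketch, not
generated code: it assumes the `_mlir_ciface_qux` wrapper shown in the example
further below and a contiguous row-major `float` buffer provided by the caller.

```cpp
#include <cstddef>
#include <cstdint>

template <typename T, size_t N>
struct MemRefDescriptor {
  T *allocated;
  T *aligned;
  intptr_t offset;
  intptr_t sizes[N];
  intptr_t strides[N];
};

// Wrapper emitted for `func @qux(%arg0: memref<?x?xf32>)` (see below).
extern "C" void _mlir_ciface_qux(MemRefDescriptor<float, 2> *);

void callQux(float *data, intptr_t rows, intptr_t cols) {
  // Contiguous row-major layout: offset 0, strides {cols, 1}.
  MemRefDescriptor<float, 2> desc{data, data, 0, {rows, cols}, {cols, 1}};
  _mlir_ciface_qux(&desc); // The descriptor is passed by pointer, as from C.
}
```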
If enabled, the option will do the following. For _external_ functions declared
in the MLIR module:
1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
as usual.
1. Add a body to the original function (making it non-external) that
    1. allocates a memref descriptor,
    1. populates it,
    1. passes the pointer to it into the newly declared interface function, and
    1. collects the result of the call and returns it to the caller.
For (non-external) functions defined in the MLIR module:
1. Define a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
as usual.
1. Populate the body of the newly defined function with IR that
    1. loads descriptors from pointers;
    1. unpacks the descriptor into individual non-aggregate values;
    1. passes these values into the original function;
    1. collects the result of the call and returns it to the caller.
Examples:
```mlir
func @qux(%arg0: memref<?x?xf32>)
// Gets converted into the following.
// Function with unpacked arguments.
llvm.func @qux(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
%arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
%arg6: !llvm.i64) {
// Populate memref descriptor (as per calling convention).
%0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// Store the descriptor in a stack-allocated space.
%8 = llvm.mlir.constant(1 : index) : !llvm.i64
%9 = llvm.alloca %8 x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
: (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
llvm.store %7, %9 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// Call the interface function.
llvm.call @_mlir_ciface_qux(%9) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
// The stored descriptor will be freed on return.
llvm.return
}
// Interface function.
llvm.func @_mlir_ciface_qux(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
```
```mlir
func @foo(%arg0: memref<?x?xf32>) {
return
}
// Gets converted into the following.
// Function with unpacked arguments.
llvm.func @foo(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
%arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
%arg6: !llvm.i64) {
llvm.return
}
// Interface function callable from C.
llvm.func @_mlir_ciface_foo(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
// Load the descriptor.
%0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// Unpack the descriptor as per calling convention.
%1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%5 = llvm.extractvalue %0[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%6 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%7 = llvm.extractvalue %0[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
: (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64,
!llvm.i64, !llvm.i64) -> ()
llvm.return
}
```
Rationale: Introducing auxiliary functions for C-compatible interfaces is
preferred to modifying the calling convention since it will minimize the effect
of C compatibility on intra-module calls or calls between MLIR-generated
functions. In particular, when calling external functions from an MLIR module in
a (parallel) loop, storing a memref descriptor on the stack can lead to stack
exhaustion and/or concurrent access to the same address. The auxiliary
interface function serves as an allocation scope in this case. Furthermore, when
targeting accelerators with separate memory spaces such as GPUs, stack-allocated
descriptors passed by pointer would have to be transferred to the device memory,
which introduces significant overhead. In such situations, auxiliary interface
functions are executed on the host and only pass the values through the device
function invocation mechanism.
## Repeated Successor Removal
Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect


@ -36,8 +36,8 @@ class LLVMType;
/// Set of callbacks that allows the customization of LLVMTypeConverter.
struct LLVMTypeConverterCustomization {
using CustomCallback =
std::function<LLVM::LLVMType(LLVMTypeConverter &, Type)>;
using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
SmallVectorImpl<Type> &)>;
/// Customize the type conversion of function arguments.
CustomCallback funcArgConverter;
@ -47,19 +47,26 @@ struct LLVMTypeConverterCustomization {
};
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a struct that contains the descriptor information. Converted
/// types are promoted to a pointer to the converted type.
LLVM::LLVMType structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type);
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type. Converted types are
/// not promoted to pointers.
LLVM::LLVMType barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type);
/// arguments to bare pointers to the MemRef element type.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
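// Example customization (a sketch, not part of this change): installing the
// bare-pointer callback in place of the default struct-expanding one.
//
//   LLVMTypeConverterCustomization customs;
//   customs.funcArgConverter = barePtrFuncArgTypeConverter;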
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result);
public:
using TypeConverter::convertType;
@ -107,6 +114,15 @@ public:
Value promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder);
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
/// Creates descriptor structs from individual values constituting them.
Operation *materializeConversion(PatternRewriter &rewriter, Type type,
ArrayRef<Value> values,
Location loc) override;
protected:
/// LLVM IR module used to parse/create types.
llvm::Module *module;
@ -133,14 +149,34 @@ private:
// by LLVM.
Type convertFloatType(FloatType type);
// Convert a memref type into an LLVM type that captures the relevant data.
// For statically-shaped memrefs, the resulting type is a pointer to the
// (converted) memref element type. For dynamically-shaped memrefs, the
// resulting type is an LLVM structure type that contains:
// 1. a pointer to the (converted) memref element type
// 2. as many index types as memref has dynamic dimensions.
/// Convert a memref type into an LLVM type that captures the relevant data.
Type convertMemRefType(MemRefType type);
/// Convert a memref type into a list of non-aggregate LLVM IR types that
/// contain all the relevant data. In particular, the list will contain:
/// - two pointers to the memref element type, followed by
/// - an integer offset, followed by
/// - one integer size per dimension of the memref, followed by
/// - one integer stride per dimension of the memref.
/// For example, memref<?x?xf32> is converted to the following list:
/// - `!llvm<"float*">` (allocated pointer),
/// - `!llvm<"float*">` (aligned pointer),
/// - `!llvm.i64` (offset),
/// - `!llvm.i64`, `!llvm.i64` (sizes),
/// - `!llvm.i64`, `!llvm.i64` (strides).
/// These types can be recomposed to a memref descriptor struct.
SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that contain all the relevant data. In particular, this list contains:
/// - an integer rank, followed by
/// - a pointer to the memref descriptor struct.
/// For example, memref<*xf32> is converted to the following list:
/// !llvm.i64 (rank)
/// !llvm<"i8*"> (type-erased pointer).
/// These types can be recomposed to an unranked memref descriptor struct.
SmallVector<Type, 2> convertUnrankedMemRefSignature();
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
@ -180,6 +216,7 @@ protected:
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
@ -234,11 +271,63 @@ public:
/// Returns the (LLVM) type this descriptor points to.
LLVM::LLVMType getElementType();
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> strides;
/// where <rank> is the MemRef rank as provided in `type`.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, MemRefType type,
ValueRange values);
/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
MemRefType type, SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues(MemRefType type);
private:
// Cached index type.
Type indexType;
};
/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
class MemRefDescriptorView {
public:
/// Constructs the view from a range of values. Infers the rank from the size
/// of the range.
explicit MemRefDescriptorView(ValueRange range);
/// Returns the allocated pointer Value.
Value allocatedPtr();
/// Returns the aligned pointer Value.
Value alignedPtr();
/// Returns the offset Value.
Value offset();
/// Returns the pos-th size Value.
Value size(unsigned pos);
/// Returns the pos-th stride Value.
Value stride(unsigned pos);
private:
/// Rank of the memref the descriptor is pointing to.
int rank;
/// Underlying range of Values.
ValueRange elements;
};
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
@ -255,6 +344,23 @@ public:
Value memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, UnrankedMemRefType type,
ValueRange values);
/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
};
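// Usage sketch (illustration only; assumes an OpBuilder `b`, a Location `loc`,
// an LLVMTypeConverter `tc`, a rank-2 MemRefType `ty` and a descriptor value
// `desc` of the corresponding struct type):
//
//   SmallVector<Value, 8> parts;
//   MemRefDescriptor::unpack(b, loc, desc, ty, parts);
//   assert(parts.size() == MemRefDescriptor::getNumUnpackedValues(ty)); // 7
//   MemRefDescriptorView view(parts); // Named access into the flat list.
//   Value aligned = view.alignedPtr();
//   Value firstSize = view.size(0);
//   // Recreate a single SSA aggregate when one is needed again.
//   Value repacked = MemRefDescriptor::pack(b, loc, tc, ty, parts);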
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the


@ -29,16 +29,21 @@ void populateStdToLLVMMemoryConversionPatters(
void populateStdToLLVMNonMemoryConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
void populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool emitCWrappers = false);
/// Collect a set of default patterns to convert from the Standard dialect to
/// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will
/// generate `llvm.alloca` instead of calls to "malloc".
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns,
bool useAlloca = false);
bool useAlloca = false,
bool emitCWrappers = false);
/// Collect a set of patterns to convert from the Standard dialect to
/// LLVM using the bare pointer calling convention for MemRef function
@ -53,7 +58,7 @@ void populateStdToLLVMBarePtrConversionPatterns(
/// Specifying `useAlloca=true` emits stack allocations instead. In the future
/// this may become an enum when we have concrete uses for other options.
std::unique_ptr<OpPassBase<ModuleOp>>
createLowerToLLVMPass(bool useAlloca = false);
createLowerToLLVMPass(bool useAlloca = false, bool emitCWrappers = false);
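// For example (a sketch assuming a PassManager `pm`), a pipeline can request
// the C-compatible wrappers when creating the pass:
//
//   pm.addPass(createLowerToLLVMPass(/*useAlloca=*/false,
//                                    /*emitCWrappers=*/true));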
namespace LLVM {
/// Make argument-taking successors of each block distinct. PHI nodes in LLVM


@ -30,6 +30,13 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
return ("arg" + Twine(arg)).toStringRef(out);
}
/// Returns true if the given name is a valid argument attribute name.
inline bool isArgAttrName(StringRef name) {
APInt unused;
return name.startswith("arg") &&
!name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
}
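// For example, isArgAttrName("arg0") and isArgAttrName("arg12") return true,
// while isArgAttrName("args") and isArgAttrName("argument") return false
// because the text following "arg" is not a decimal integer.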
/// Return the name of the attribute used for function results.
inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
out.clear();


@ -113,6 +113,8 @@ private:
}
void declareCudaFunctions(Location loc);
void addParamToList(OpBuilder &builder, Location loc, Value param, Value list,
unsigned pos, Value one);
Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
Value generateKernelNameConstant(StringRef name, Location loc,
OpBuilder &builder);
@ -231,6 +233,35 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
}
}
/// Emits the IR with the following structure:
///
/// %data = llvm.alloca 1 x type-of(<param>)
/// llvm.store <param>, %data
/// %typeErased = llvm.bitcast %data to !llvm<"i8*">
/// %addr = llvm.getelementptr <list>[<pos>]
/// llvm.store %typeErased, %addr
///
/// This is necessary to construct the list of arguments passed to the kernel
/// function as accepted by cuLaunchKernel, i.e. as a void** that points to list
/// of stack-allocated type-erased pointers to the actual arguments.
void GpuLaunchFuncToCudaCallsPass::addParamToList(OpBuilder &builder,
Location loc, Value param,
Value list, unsigned pos,
Value one) {
auto memLocation = builder.create<LLVM::AllocaOp>(
loc, param.getType().cast<LLVM::LLVMType>().getPointerTo(), one,
/*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, param, memLocation);
auto casted =
builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
auto index = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(pos));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), list,
ArrayRef<Value>{index});
builder.create<LLVM::StoreOp>(loc, casted, gep);
}
// Generates a parameters array to be used with a CUDA kernel launch call. The
// arguments are extracted from the launchOp.
// The generated code is essentially as follows:
@ -241,53 +272,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
// return %array
Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
OpBuilder &builder) {
// Get the launch target.
auto containingModule = launchOp.getParentOfType<ModuleOp>();
if (!containingModule)
return {};
auto gpuModule = containingModule.lookupSymbol<gpu::GPUModuleOp>(
launchOp.getKernelModuleName());
if (!gpuModule)
return {};
auto gpuFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>(launchOp.kernel());
if (!gpuFunc)
return {};
unsigned numArgs = gpuFunc.getNumArguments();
auto numKernelOperands = launchOp.getNumKernelOperands();
Location loc = launchOp.getLoc();
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(1));
// Provision twice as much for the `array` to allow up to one level of
// indirection for each argument.
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
loc, getInt32Type(), builder.getI32IntegerAttr(numArgs));
auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
arraySize, /*alignment=*/0);
unsigned pos = 0;
for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
auto operand = launchOp.getKernelOperand(idx);
auto llvmType = operand.getType().cast<LLVM::LLVMType>();
Value memLocation = builder.create<LLVM::AllocaOp>(
loc, llvmType.getPointerTo(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, operand, memLocation);
auto casted =
builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
// Assume all struct arguments come from MemRef. If this assumption no longer
// holds, `launchOp` needs to be lowered from MemRefType directly, not after
// the LLVM conversion has taken place and the MemRef information is lost.
// Extra level of indirection in the `array`:
// the descriptor pointer is registered via @mcuMemHostRegisterPtr
if (llvmType.isStructTy()) {
auto registerFunc =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
ArrayRef<Value>{nullPtr, one});
auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
builder.getSymbolRefAttr(registerFunc),
ArrayRef<Value>{casted, size});
Value memLocation = builder.create<LLVM::AllocaOp>(
loc, getPointerPointerType(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, casted, memLocation);
casted =
builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
if (!llvmType.isStructTy()) {
addParamToList(builder, loc, operand, array, pos++, one);
continue;
}
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(idx));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
ArrayRef<Value>{index});
builder.create<LLVM::StoreOp>(loc, casted, gep);
// Put individual components of a memref descriptor into the flat argument
// list. We cannot use unpackMemref from LLVM lowering here because we have
// no access to MemRefType that had been lowered away.
for (int32_t j = 0, ej = llvmType.getStructNumElements(); j < ej; ++j) {
auto elemType = llvmType.getStructElementType(j);
if (elemType.isArrayTy()) {
for (int32_t k = 0, ek = elemType.getArrayNumElements(); k < ek; ++k) {
Value elem = builder.create<LLVM::ExtractValueOp>(
loc, elemType.getArrayElementType(), operand,
builder.getI32ArrayAttr({j, k}));
addParamToList(builder, loc, elem, array, pos++, one);
}
} else {
assert((elemType.isIntegerTy() || elemType.isFloatTy() ||
elemType.isDoubleTy() || elemType.isPointerTy()) &&
"expected scalar type");
Value strct = builder.create<LLVM::ExtractValueOp>(
loc, elemType, operand, builder.getI32ArrayAttr(j));
addParamToList(builder, loc, strct, array, pos++, one);
}
}
}
return array;
}
@ -392,6 +436,10 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
if (!paramsArray) {
launchOp.emitOpError() << "cannot pass given parameters to the kernel";
return signalPassFailure();
}
auto nullpointer =
builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
builder.create<LLVM::CallOp>(


@ -564,8 +564,8 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false,
signatureConversion);
// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
@ -651,25 +651,6 @@ struct GPUFuncOpLowering : LLVMOpLowering {
rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
signatureConversion);
{
// For memref-typed arguments, insert the relevant loads in the beginning
// of the block to comply with the LLVM dialect calling convention. This
// needs to be done after signature conversion to get the right types.
OpBuilder::InsertionGuard guard(rewriter);
Block &block = llvmFuncOp.front();
rewriter.setInsertionPointToStart(&block);
for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
if (!en.value().isa<MemRefType>() &&
!en.value().isa<UnrankedMemRefType>())
continue;
BlockArgument arg = block.getArgument(en.index());
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
}
}
rewriter.eraseOp(gpuFuncOp);
return matchSuccess();
}


@ -577,7 +577,8 @@ void ConvertLinalgToLLVMPass::runOnModule() {
LinalgTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
/*emitCWrappers=*/true);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToStandardConversionPatterns(patterns, &getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());


@ -30,6 +30,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
@ -43,6 +44,11 @@ static llvm::cl::opt<bool>
llvm::cl::desc("Replace emission of malloc/free by alloca"),
llvm::cl::init(false));
static llvm::cl::opt<bool>
clEmitCWrappers(PASS_NAME "-emit-c-wrappers",
llvm::cl::desc("Emit C-compatible wrapper functions"),
llvm::cl::init(false));
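// Assuming the usual mlir-opt flag plumbing, the wrappers can then be
// requested on the command line together with the pass itself, e.g.:
//   mlir-opt -convert-std-to-llvm -convert-std-to-llvm-emit-c-wrappers in.mlir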
static llvm::cl::opt<bool> clUseBarePtrCallConv(
PASS_NAME "-use-bare-ptr-memref-call-conv",
llvm::cl::desc("Replace FuncOp's MemRef arguments with "
@ -66,18 +72,32 @@ LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() {
funcArgConverter = structFuncArgTypeConverter;
}
// Callback to convert function argument types. It converts a MemRef function
// arguments to a struct that contains the descriptor information. Converted
// types are promoted to a pointer to the converted type.
LLVM::LLVMType mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type) {
auto converted =
converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
if (auto memref = type.dyn_cast<MemRefType>()) {
auto converted = converter.convertMemRefSignature(memref);
if (converted.empty())
return failure();
result.append(converted.begin(), converted.end());
return success();
}
if (type.isa<UnrankedMemRefType>()) {
auto converted = converter.convertUnrankedMemRefSignature();
if (converted.empty())
return failure();
result.append(converted.begin(), converted.end());
return success();
}
auto converted = converter.convertType(type);
if (!converted)
return {};
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
converted = converted.getPointerTo();
return converted;
return failure();
result.push_back(converted);
return success();
}
/// Convert a MemRef type to a bare pointer to the MemRef element type.
@ -96,15 +116,26 @@ static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
}
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type. Converted types are
/// not promoted to pointers.
LLVM::LLVMType mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type) {
/// arguments to bare pointers to the MemRef element type.
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
// TODO: Add support for unranked memref.
if (auto memrefTy = type.dyn_cast<MemRefType>())
return convertMemRefTypeToBarePtr(converter, memrefTy)
.dyn_cast_or_null<LLVM::LLVMType>();
return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
if (!llvmTy)
return failure();
result.push_back(llvmTy);
return success();
}
auto llvmTy = converter.convertType(type);
if (!llvmTy)
return failure();
result.push_back(llvmTy);
return success();
}
/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
@ -165,6 +196,33 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
return converted.getPointerTo();
}
/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
/// values.
SmallVector<Type, 5>
LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
SmallVector<Type, 5> results;
assert(isStrided(type) &&
"Non-strided layout maps must have been normalized away");
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto indexTy = getIndexType();
results.insert(results.begin(), 2,
elementType.getPointerTo(type.getMemorySpace()));
results.push_back(indexTy);
auto rank = type.getRank();
results.insert(results.end(), 2 * rank, indexTy);
return results;
}
/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
/// pointer to descriptor".
SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
}
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
@ -175,9 +233,8 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs())) {
Type type = en.value();
auto converted = customizations.funcArgConverter(*this, type)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
SmallVector<Type, 8> converted;
if (failed(customizations.funcArgConverter(*this, type, converted)))
return {};
result.addInputs(en.index(), converted);
}
@ -199,6 +256,47 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
LLVM::LLVMType
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
SmallVector<LLVM::LLVMType, 4> inputs;
for (Type t : type.getInputs()) {
auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
return {};
if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
converted = converted.getPointerTo();
inputs.push_back(converted);
}
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
}
/// Creates descriptor structs from individual values constituting them.
Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
Type type,
ArrayRef<Value> values,
Location loc) {
if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
unrankedMemRefType, values)
.getDefiningOp();
auto memRefType = type.dyn_cast<MemRefType>();
assert(memRefType && "1->N conversion is only supported for memrefs");
return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
.getDefiningOp();
}
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by
@ -473,6 +571,85 @@ LLVM::LLVMType MemRefDescriptor::getElementType() {
kAlignedPtrPosInMemRefDescriptor);
}
/// Creates a MemRef descriptor structure from a list of individual values
/// composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> strides;
/// where <rank> is the MemRef rank as provided in `type`.
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, MemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
int64_t rank = type.getRank();
for (unsigned i = 0; i < rank; ++i) {
d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
}
return d;
}
/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
MemRefType type,
SmallVectorImpl<Value> &results) {
int64_t rank = type.getRank();
results.reserve(results.size() + getNumUnpackedValues(type));
MemRefDescriptor d(packed);
results.push_back(d.allocatedPtr(builder, loc));
results.push_back(d.alignedPtr(builder, loc));
results.push_back(d.offset(builder, loc));
for (int64_t i = 0; i < rank; ++i)
results.push_back(d.size(builder, loc, i));
for (int64_t i = 0; i < rank; ++i)
results.push_back(d.stride(builder, loc, i));
}
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
// Two pointers, offset, <rank> sizes, <rank> strides.
return 3 + 2 * type.getRank();
}
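// For example, a descriptor for memref<?x?xf32> (rank 2) unpacks into
// 3 + 2 * 2 = 7 values: two pointers, one offset, two sizes and two strides.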
/*============================================================================*/
/* MemRefDescriptorView implementation. */
/*============================================================================*/
MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
: rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
Value MemRefDescriptorView::allocatedPtr() {
return elements[kAllocatedPtrPosInMemRefDescriptor];
}
Value MemRefDescriptorView::alignedPtr() {
return elements[kAlignedPtrPosInMemRefDescriptor];
}
Value MemRefDescriptorView::offset() {
return elements[kOffsetPosInMemRefDescriptor];
}
Value MemRefDescriptorView::size(unsigned pos) {
return elements[kSizePosInMemRefDescriptor + pos];
}
Value MemRefDescriptorView::stride(unsigned pos) {
return elements[kSizePosInMemRefDescriptor + rank + pos];
}
/*============================================================================*/
/* UnrankedMemRefDescriptor implementation */
/*============================================================================*/
@ -504,6 +681,34 @@ void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
Location loc, Value v) {
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter,
UnrankedMemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
return d;
}
/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
Value packed,
SmallVectorImpl<Value> &results) {
UnrankedMemRefDescriptor d(packed);
results.reserve(results.size() + 2);
results.push_back(d.rank(builder, loc));
results.push_back(d.memRefDescPtr(builder, loc));
}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
@ -551,9 +756,144 @@ protected:
LLVM::LLVMDialect &dialect;
};
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is("std.varargs") ||
(filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
continue;
result.push_back(attr);
}
}
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. This function can be called from C
/// by passing a pointer to a C struct corresponding to a memref descriptor.
/// Internally, the auxiliary function unpacks the descriptor into individual
/// components and forwards them to `newFuncOp`.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
auto type = funcOp.getType();
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
attributes);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
SmallVector<Value, 8> args;
for (auto &en : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(en.index());
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
continue;
}
if (en.value().isa<UnrankedMemRefType>()) {
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
continue;
}
args.push_back(wrapperFuncOp.getArgument(en.index()));
}
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
}
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
/// the auxiliary function. This auxiliary external function is now compatible
/// with functions defined in C using pointers to C structs corresponding to a
/// memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
LLVM::LLVMType wrapperType =
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
// This conversion can only fail if it could not convert one of the argument
// types. But since it has already been applied to the non-wrapper function,
// it would have failed earlier and never reached this point.
assert(wrapperType && "unexpected type conversion failure");
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
// Create the auxiliary function.
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperType, LLVM::Linkage::External, attributes);
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
// Get a ValueRange containing arguments. Note that ValueRange is
// currently not constructible from a pair of iterators pointing to
// BlockArgument.
FunctionType type = funcOp.getType();
SmallVector<Value, 8> args;
args.reserve(type.getNumInputs());
auto wrapperArgIters = newFuncOp.getArguments();
SmallVector<Value, 8> wrapperArgs(wrapperArgIters.begin(),
wrapperArgIters.end());
ValueRange wrapperArgsRange(wrapperArgs);
// Iterate over the inputs of the original function and pack values into
// memref descriptors if the original type is a memref.
for (auto &en : llvm::enumerate(type.getInputs())) {
Value arg;
int numToDrop = 1;
auto memRefType = en.value().dyn_cast<MemRefType>();
auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
if (memRefType || unrankedMemRefType) {
numToDrop = memRefType
? MemRefDescriptor::getNumUnpackedValues(memRefType)
: UnrankedMemRefDescriptor::getNumUnpackedValues();
Value packed =
memRefType
? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
wrapperArgsRange.take_front(numToDrop))
: UnrankedMemRefDescriptor::pack(
builder, loc, typeConverter, unrankedMemRefType,
wrapperArgsRange.take_front(numToDrop));
auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value allocated =
builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
builder.create<LLVM::StoreOp>(loc, packed, allocated);
arg = allocated;
} else {
arg = wrapperArgsRange[0];
}
args.push_back(arg);
wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
}
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
protected:
using LLVMLegalizationPattern::LLVMLegalizationPattern;
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
using UnsignedTypePair = std::pair<unsigned, Type>;
// Gather the positions and types of memref-typed arguments in a given
@ -579,14 +919,24 @@ protected:
auto llvmType = lowering.convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
// Only retain those attributes that are not constructed by build.
// Propagate argument attributes to all converted arguments obtained after
// converting a given original argument.
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : funcOp.getAttrs()) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is("std.varargs"))
filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
attributes);
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto attr = impl::getArgAttrDict(funcOp, i);
if (!attr)
continue;
attributes.push_back(attr);
auto mapping = result.getInputMapping(i);
assert(mapping.hasValue() && "unexpected deletion of function argument");
SmallString<8> name;
for (size_t j = mapping->inputNo, e = mapping->inputNo + mapping->size; j < e;
     ++j) {
impl::getArgAttrName(j, name);
attributes.push_back(rewriter.getNamedAttr(name, attr));
}
}
// Create an LLVM function, use external linkage by default until MLIR
@ -607,34 +957,33 @@ protected:
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
using FuncOpConversionBase::FuncOpConversionBase;
FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
bool emitCWrappers)
: FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
// Store the positions of memref-typed arguments so that we can emit loads
// from them to follow the calling convention.
SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
// Insert loads from memref descriptor pointers in function bodies.
if (!newFuncOp.getBody().empty()) {
Block *firstBlock = &newFuncOp.getBody().front();
rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
for (const auto &argInfo : promotedArgsInfo) {
BlockArgument arg = firstBlock->getArgument(argInfo.first);
Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
}
if (emitWrappers) {
if (newFuncOp.isExternal())
wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp,
newFuncOp);
else
wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp,
newFuncOp);
}
rewriter.eraseOp(op);
return matchSuccess();
}
private:
/// If true, also create the adaptor functions having signatures compatible
/// with those produced by clang.
const bool emitWrappers;
};
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
@ -2273,14 +2622,17 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
}
void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool emitCWrappers) {
patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
emitCWrappers);
}
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool useAlloca) {
populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
bool useAlloca, bool emitCWrappers) {
populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
emitCWrappers);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
}
@ -2346,13 +2698,20 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
for (auto it : llvm::zip(opOperands, operands)) {
auto operand = std::get<0>(it);
auto llvmOperand = std::get<1>(it);
if (!operand.getType().isa<MemRefType>() &&
!operand.getType().isa<UnrankedMemRefType>()) {
promotedOperands.push_back(operand);
if (operand.getType().isa<UnrankedMemRefType>()) {
UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
promotedOperands);
continue;
}
promotedOperands.push_back(
promoteOneMemRefDescriptor(loc, llvmOperand, builder));
if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
MemRefDescriptor::unpack(builder, loc, llvmOperand,
operand.getType().cast<MemRefType>(),
promotedOperands);
continue;
}
promotedOperands.push_back(operand);
}
return promotedOperands;
}
@ -2362,11 +2721,21 @@ namespace {
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
/// Creates an LLVM lowering pass.
explicit LLVMLoweringPass(bool useAlloca = false,
bool useBarePtrCallConv = false)
: useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv) {}
bool useBarePtrCallConv = false,
bool emitCWrappers = false)
: useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv),
emitCWrappers(emitCWrappers) {}
/// Run the dialect converter on the module.
void runOnModule() override {
if (useBarePtrCallConv && emitCWrappers) {
getModule().emitError()
<< "incompatible conversion options: bare-pointer calling convention "
"and C wrapper emission";
signalPassFailure();
return;
}
ModuleOp m = getModule();
LLVM::ensureDistinctSuccessors(m);
@ -2380,7 +2749,8 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
useAlloca);
else
populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca);
populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca,
emitCWrappers);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
@ -2393,19 +2763,23 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
/// Convert memrefs to bare pointers in function signatures.
bool useBarePtrCallConv;
/// Emit wrappers for C-compatible pointer-to-struct memref descriptors.
bool emitCWrappers;
};
} // end namespace
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(bool useAlloca) {
return std::make_unique<LLVMLoweringPass>(useAlloca);
mlir::createLowerToLLVMPass(bool useAlloca, bool emitCWrappers) {
return std::make_unique<LLVMLoweringPass>(useAlloca,
                                          /*useBarePtrCallConv=*/false,
                                          emitCWrappers);
}
static PassRegistration<LLVMLoweringPass>
pass("convert-std-to-llvm",
pass(PASS_NAME,
"Convert scalar and vector operations from the "
"Standard to the LLVM dialect",
[] {
return std::make_unique<LLVMLoweringPass>(
clUseAlloca.getValue(), clUseBarePtrCallConv.getValue());
clUseAlloca.getValue(), clUseBarePtrCallConv.getValue(),
clEmitCWrappers.getValue());
});
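
For orientation, a minimal sketch (not part of the patch) of enabling the new wrapper emission programmatically, assuming the pass-manager setup that `mlir-opt` performs internally and whichever conversion header declares `createLowerToLLVMPass`:

```c++
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult lowerWithCWrappers(mlir::ModuleOp module,
                                       mlir::MLIRContext &context) {
  mlir::PassManager pm(&context);
  // Keep malloc-based allocation; additionally emit the C-compatible
  // pointer-to-struct wrappers. Note that the bare-pointer convention and
  // wrapper emission are mutually exclusive, as checked in runOnModule above.
  pm.addPass(mlir::createLowerToLLVMPass(/*useAlloca=*/false,
                                         /*emitCWrappers=*/true));
  return pm.run(module);
}
```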

View File

@ -90,26 +90,27 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return launchOp.emitOpError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
// TODO(ntv,zinenko,herhut): if the kernel function has been converted to
// the LLVM dialect but the caller hasn't (which happens during separate
// compilation), do not check type correspondence as it would
// require the verifier to be aware of the LLVM type conversion.
if (kernelLLVMFunction)
return success();
unsigned actualNumArguments = launchOp.getNumKernelOperands();
unsigned expectedNumArguments = kernelLLVMFunction
? kernelLLVMFunction.getNumArguments()
: kernelGPUFunction.getNumArguments();
unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
if (expectedNumArguments != actualNumArguments)
return launchOp.emitOpError("got ")
<< actualNumArguments << " kernel operands but expected "
<< expectedNumArguments;
// Due to the ordering of the current impl of lowering and LLVMLowering,
// type checks need to be temporarily disabled.
// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
// to encode target module" has landed.
// auto functionType = kernelFunc.getType();
// for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
// if (getKernelOperand(i).getType() != functionType.getInput(i)) {
// return emitOpError("type of function argument ")
// << i << " does not match";
// }
// }
auto functionType = kernelGPUFunction.getType();
for (unsigned i = 0; i < expectedNumArguments; ++i) {
if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
return launchOp.emitOpError("type of function argument ")
<< i << " does not match";
}
}
return success();
});

View File

@ -401,8 +401,7 @@ Block *ArgConverter::applySignatureConversion(
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
Operation *cast = typeConverter->materializeConversion(
rewriter, origArg.getType(), replArgs, loc);
assert(cast->getNumResults() == 1 &&
cast->getNumOperands() == replArgs.size());
assert(cast->getNumResults() == 1);
mapping.map(origArg, cast->getResult(0));
info.argInfo[i] =
ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));

View File

@ -6,8 +6,8 @@ module attributes {gpu.container_module} {
// CHECK: llvm.mlir.global internal constant @[[global:.*]]("CUBIN")
gpu.module @kernel_module attributes {nvvm.cubin = "CUBIN"} {
gpu.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
gpu.return
llvm.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
llvm.return
}
}

View File

@ -1,25 +1,11 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> {dialect.a = true, dialect.b = 4 : i64}) {
// CHECK-NEXT: llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-LABEL: func @check_attributes
// When expanding the memref to multiple arguments, argument attributes are replicated.
// CHECK-COUNT-7: {dialect.a = true, dialect.b = 4 : i64}
func @check_attributes(%static: memref<10x20xf32> {dialect.a = true, dialect.b = 4 : i64 }) {
%c0 = constant 0 : index
%0 = load %static[%c0, %c0]: memref<10x20xf32>
return
}
// CHECK-LABEL: func @external_func(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
// CHECK: func @call_external(%[[arg:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
// CHECK: %[[ld:.*]] = llvm.load %[[arg]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[alloca:.*]] = llvm.alloca %[[c1]] x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK: llvm.store %[[ld]], %[[alloca]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK: call @external_func(%[[alloca]]) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
func @external_func(memref<10x20xf32>)
func @call_external(%static: memref<10x20xf32>) {
call @external_func(%static) : (memref<10x20xf32>) -> ()
return
}
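
Read through a C lens, the expanded signature corresponds to the following hypothetical prototype: seven scalar values for the rank-2 memref (two pointers, an offset, two sizes, two strides), matching the CHECK-COUNT-7 on the attribute test above.

```c++
#include <cstdint>

// Hypothetical C-level equivalent of the expanded memref<10x20xf32> argument.
extern "C" void external_func(float *allocated, float *aligned, int64_t offset,
                              int64_t size0, int64_t size1, int64_t stride0,
                              int64_t stride1);
```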

View File

@ -1,14 +1,25 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
// CHECK-LABEL: func @check_strided_memref_arguments(
// CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-LABEL: func @check_strided_memref_arguments(
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>,
%dynamic : memref<?x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>,
%mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) {
return
}
// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
// CHECK-LABEL: func @check_arguments
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
return
}
@ -16,7 +27,7 @@ func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %m
// CHECK-LABEL: func @mixed_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> {
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
@ -45,10 +56,9 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
return %0 : memref<?x42x?xf32>
}
// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) {
// CHECK-LABEL: func @mixed_dealloc
func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x42x?xf32>
@ -59,7 +69,7 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-LABEL: func @dynamic_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@ -83,10 +93,9 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
return %0 : memref<?x?xf32>
}
// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
// CHECK-LABEL: func @dynamic_dealloc
func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
dealloc %arg0 : memref<?x?xf32>
@ -94,10 +103,12 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
}
// CHECK-LABEL: func @mixed_load(
// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">,
// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
// CHECK: %[[I:.*]]: !llvm.i64,
// CHECK: %[[J:.*]]: !llvm.i64)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@ -112,10 +123,8 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
}
// CHECK-LABEL: func @dynamic_load(
// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@ -131,8 +140,7 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @prefetch
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@ -161,8 +169,7 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @dynamic_store
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@ -171,15 +178,14 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
store %val, %dynamic[%i, %j] : memref<?x?xf32>
return
}
// CHECK-LABEL: func @mixed_store
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@ -188,74 +194,66 @@ func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32)
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
store %val, %mixed[%i, %j] : memref<42x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_static_to_dynamic
func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_static_to_mixed
func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
return
}
// CHECK-LABEL: func @memref_cast_dynamic_to_static
func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
return
}
// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_dynamic
func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_static
func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
return
}
// CHECK-LABEL: func @memref_cast_mixed_to_mixed
func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
%0 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
return
}
// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: llvm.store %{{.*}}, %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK: llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
@ -266,19 +264,17 @@ func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
// CHECK: %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i8* }">
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
%0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
return
}
// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
// CHECK-LABEL: func @mixed_memref_dim
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
%0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK-NEXT: llvm.extractvalue %[[ld:.*]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
%1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
%2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>

View File

@ -18,12 +18,12 @@ func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
//CHECK: llvm.func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))
// Check that memrefs are converted to pointers-to-struct if they appear as function arguments.
// CHECK: llvm.func @memref_call_conv(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
// Check that memrefs are converted to argument packs if they appear as function arguments.
// CHECK: llvm.func @memref_call_conv(!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
func @memref_call_conv(%arg0: memref<?xf32>)
// Same in nested functions.
// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void ({ float*, float*, i64, [1 x i64], [1 x i64] }*)*">)
// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void (float*, float*, i64, i64, i64)*">)
func @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
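
As a sketch, the converted nested function type above corresponds to this hypothetical C++ function-pointer type for a rank-1 memref argument:

```c++
#include <cstdint>

// Mirrors !llvm<"void (float*, float*, i64, i64, i64)*"> from the CHECK above.
using MemRefCallConvFn = void (*)(float *allocated, float *aligned,
                                  int64_t offset, int64_t size,
                                  int64_t stride);
```
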
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {

View File

@ -10,7 +10,10 @@ func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
// -----
// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
// CHECK-LABEL: func @check_static_return
// CHECK-COUNT-2: !llvm<"float*">
// CHECK-COUNT-5: !llvm.i64
// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-LABEL: func @check_static_return
// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
@ -76,11 +79,10 @@ func @zero_d_alloc() -> memref<f32> {
// -----
// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
// CHECK-LABEL: func @zero_d_dealloc
// BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) {
func @zero_d_dealloc(%arg0: memref<f32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@ -96,7 +98,7 @@ func @zero_d_dealloc(%arg0: memref<f32>) {
// CHECK-LABEL: func @aligned_1d_alloc(
// BAREPTR-LABEL: func @aligned_1d_alloc(
func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
@ -150,13 +152,13 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
// BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @static_alloc() -> memref<32x18xf32> {
// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
@ -177,11 +179,10 @@ func @static_alloc() -> memref<32x18xf32> {
// -----
// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
// CHECK-LABEL: func @static_dealloc
// BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) {
func @static_dealloc(%static: memref<10x8xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
@ -194,11 +195,10 @@ func @static_dealloc(%static: memref<10x8xf32>) {
// -----
// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
// CHECK-LABEL: func @zero_d_load
// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float
func @zero_d_load(%arg0: memref<f32>) -> f32 {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
@ -214,20 +214,22 @@ func @zero_d_load(%arg0: memref<f32>) -> f32 {
// -----
// CHECK-LABEL: func @static_load(
// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">,
// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
// CHECK: %[[I:.*]]: !llvm.i64,
// CHECK: %[[J:.*]]: !llvm.i64)
// BAREPTR-LABEL: func @static_load
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@ -246,15 +248,14 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// -----
// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
// CHECK-LABEL: func @zero_d_store
// BAREPTR-LABEL: func @zero_d_store
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float)
func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*">
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@ -270,17 +271,16 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// BAREPTR-LABEL: func @static_store
// BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">
func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@ -298,11 +298,10 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f
// -----
// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
// CHECK-LABEL: func @static_memref_dim
// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
%0 = dim %static, 0 : memref<42x32x15x13x27xf32>

View File

@ -728,9 +728,15 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
}
// CHECK-LABEL: func @subview(
// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
// CHECK-COUNT-2: !llvm<"float*">,
// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64,
// CHECK: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64,
// CHECK: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64,
// CHECK: %[[ARG2:.*]]: !llvm.i64)
func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@ -754,9 +760,10 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg
}
// CHECK-LABEL: func @subview_const_size(
// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@ -782,9 +789,10 @@ func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 +
}
// CHECK-LABEL: func @subview_const_stride(
// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

View File

@ -1,8 +1,7 @@
// RUN: mlir-opt %s -convert-std-to-llvm -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @address_space(
// CHECK: %{{.*}}: !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">)
// CHECK: llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">
// CHECK-SAME: !llvm<"float addrspace(7)*">
func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
%0 = alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
%1 = constant 7 : index

View File

@ -175,24 +175,22 @@ module attributes {gpu.container_module} {
// -----
gpu.module @kernels {
gpu.func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
gpu.return
module attributes {gpu.container_module} {
gpu.module @kernels {
gpu.func @kernel_1(%arg1 : f32) attributes { gpu.kernel } {
gpu.return
}
}
func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
// expected-error@+1 {{type of function argument 0 does not match}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
{kernel = "kernel_1", kernel_module = @kernels}
: (index, index, index, index, index, index, f32) -> ()
return
}
}
// Due to the ordering of the current impl of lowering and LLVMLowering, type
// checks need to be temporarily disabled.
// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
// to encode target module" has landed.
// func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
// // expected-error@+1 {{type of function argument 0 does not match}}
// "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
// {kernel = "kernel_1"}
// : (index, index, index, index, index, index, f32) -> ()
// return
// }
// -----
func @illegal_dimension() {

View File

@ -52,9 +52,11 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
return
}
// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64 }*">) -> ()
// CHECK-LABEL: func @dot
// CHECK: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}) :
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64
func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?, 1]>) {
%c0 = constant 0 : index
@ -83,7 +85,9 @@ func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memre
return
}
// CHECK-LABEL: func @copy
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32({{.*}}) :
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
@ -128,9 +132,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// Call external copy after promoting input and output structs to pointers
// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
// Call external copy.
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32
#matmul_accesses = [
affine_map<(m, n, k) -> (m, k)>,
@ -163,7 +166,10 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
return
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) :
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
@ -195,7 +201,10 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
}
// CHECK-LABEL: func @matmul_vec_indexed(
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}) :
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
func @reshape_static(%arg0: memref<3x4x5xf32>) {
// Reshapes that expand and collapse back a contiguous tensor with some 1's.

View File

@ -15,32 +15,35 @@
#include <assert.h>
#include <iostream>
extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
float f) {
extern "C" void
_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
X->data[X->offset] = f;
}
extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
float f) {
extern "C" void
_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
float f) {
for (unsigned i = 0; i < X->sizes[0]; ++i)
*(X->data + X->offset + i * X->strides[0]) = f;
}
extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
float f) {
extern "C" void
_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
float f) {
for (unsigned i = 0; i < X->sizes[0]; ++i)
for (unsigned j = 0; j < X->sizes[1]; ++j)
*(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
}
extern "C" void linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
StridedMemRefType<float, 0> *O) {
extern "C" void
_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
StridedMemRefType<float, 0> *O) {
O->data[O->offset] = I->data[I->offset];
}
extern "C" void
linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
StridedMemRefType<float, 1> *O) {
_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
StridedMemRefType<float, 1> *O) {
if (I->sizes[0] != O->sizes[0]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *I);
@ -52,9 +55,8 @@ linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
I->data[I->offset + i * I->strides[0]];
}
extern "C" void
linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
StridedMemRefType<float, 2> *O) {
extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *I);
@ -69,10 +71,9 @@ linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
I->data[I->offset + i * si0 + j * si1];
}
extern "C" void
linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
StridedMemRefType<float, 1> *Y,
StridedMemRefType<float, 0> *Z) {
extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
StridedMemRefType<float, 0> *Z) {
if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
std::cerr << "Incompatible strided memrefs\n";
printMemRefMetaData(std::cerr, *X);
@ -85,7 +86,7 @@ linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
Y->data + Y->offset, Y->strides[0]);
}
extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
StridedMemRefType<float, 2> *C) {
if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||

View File

@ -25,33 +25,34 @@
#endif // _WIN32
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X, float f);
_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
float f);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
StridedMemRefType<float, 0> *O);
_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
StridedMemRefType<float, 0> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
StridedMemRefType<float, 1> *O);
_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
StridedMemRefType<float, 1> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
StridedMemRefType<float, 2> *O);
_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
StridedMemRefType<float, 1> *Y,
StridedMemRefType<float, 0> *Z);
_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
StridedMemRefType<float, 0> *Z);
extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
StridedMemRefType<float, 2> *C);

View File

@ -261,23 +261,27 @@ template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
// Currently exposed C API.
////////////////////////////////////////////////////////////////////////////////
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_i8(UnrankedMemRefType<int8_t> *M);
_mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_f32(UnrankedMemRefType<float> *M);
_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(int64_t rank,
void *ptr);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_1d_f32(StridedMemRefType<float, 1> *M);
_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_2d_f32(StridedMemRefType<float, 2> *M);
_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_3d_f32(StridedMemRefType<float, 3> *M);
_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_4d_f32(StridedMemRefType<float, 4> *M);
_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
_mlir_ciface_print_memref_vector_4x4xf32(
StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
// Small runtime support "lib" for vector.print lowering.
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);

View File

@ -16,8 +16,8 @@
#include <cinttypes>
#include <cstdio>
extern "C" void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
impl::printMemRef(*M);
}
@ -26,7 +26,7 @@ print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr))); \
break
extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
printUnrankedMemRefMetaData(std::cout, *M);
int rank = M->rank;
void *ptr = M->descriptor;
@ -42,7 +42,7 @@ extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
}
}
extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
printUnrankedMemRefMetaData(std::cout, *M);
int rank = M->rank;
void *ptr = M->descriptor;
@ -58,19 +58,31 @@ extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
}
}
extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
extern "C" void print_memref_f32(int64_t rank, void *ptr) {
UnrankedMemRefType<float> descriptor;
descriptor.rank = rank;
descriptor.descriptor = ptr;
_mlir_ciface_print_memref_f32(&descriptor);
}
extern "C" void
_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
impl::printMemRef(*M);
}
extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
extern "C" void
_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
impl::printMemRef(*M);
}
extern "C" void print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
extern "C" void
_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
impl::printMemRef(*M);
}
extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
extern "C" void
_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
impl::printMemRef(*M);
}
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
extern "C" void
_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
impl::printMemRef(*M);
}
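
A hedged usage sketch of the plain entry point added above, assuming the runner-utils header is included; the 4-element buffer is only for illustration.

```c++
void printExample() {
  float buffer[4] = {0.f, 1.f, 2.f, 3.f};
  StridedMemRefType<float, 1> ref;
  ref.basePtr = buffer; // allocated pointer
  ref.data = buffer;    // aligned pointer
  ref.offset = 0;
  ref.sizes[0] = 4;
  ref.strides[0] = 1;
  // Dispatches on rank and forwards to _mlir_ciface_print_memref_f32.
  print_memref_f32(/*rank=*/1, /*ptr=*/&ref);
}
```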

View File

@ -17,12 +17,13 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
return
}
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
func @print_memref_1d_f32(memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@ -96,11 +96,34 @@ void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
std::fill_n(arg->data, count, value);
mcuMemHostRegister(arg->data, count * sizeof(T));
}
extern "C" void
mcuMemHostRegisterMemRef1dFloat(const MemRefType<float, 1> *arg) {
mcuMemHostRegisterMemRef(arg, 1.23f);
extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
float *aligned, int64_t offset,
int64_t size, int64_t stride) {
MemRefType<float, 1> descriptor;
descriptor.basePtr = allocated;
descriptor.data = aligned;
descriptor.offset = offset;
descriptor.sizes[0] = size;
descriptor.strides[0] = stride;
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
}
extern "C" void
mcuMemHostRegisterMemRef3dFloat(const MemRefType<float, 3> *arg) {
mcuMemHostRegisterMemRef(arg, 1.23f);
extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
float *aligned, int64_t offset,
int64_t size0, int64_t size1,
int64_t size2, int64_t stride0,
int64_t stride1,
int64_t stride2) {
MemRefType<float, 3> descriptor;
descriptor.basePtr = allocated;
descriptor.data = aligned;
descriptor.offset = offset;
descriptor.sizes[0] = size0;
descriptor.strides[0] = stride0;
descriptor.sizes[1] = size1;
descriptor.strides[1] = stride1;
descriptor.sizes[2] = size2;
descriptor.strides[2] = stride2;
mcuMemHostRegisterMemRef(&descriptor, 1.23f);
}