[mlir] use unpacked memref descriptors at function boundaries
The existing (default) calling convention for memrefs in standard-to-LLVM conversion was motivated by interfacing with LLVM IR produced from C sources. In particular, it passes a pointer to the memref descriptor structure when calling the function, so the descriptor is allocated on the stack before the call. This convention leads to several problems. PR44644 indicates a problem with stack exhaustion when calling functions with memref-typed arguments in a loop. Allocating outside of the loop may lead to concurrent-access problems if the loop is parallel. When targeting GPUs, the contents of the stack-allocated memory for the descriptor (passed by pointer) need to be explicitly copied to the device. Using an aggregate type also makes it impossible to attach pointer-specific argument attributes pertaining to alignment and aliasing in the LLVM dialect.

Change the default calling convention for memrefs in standard-to-LLVM conversion so that a memref is transformed into a list of arguments, each of primitive type, that together make up the memref descriptor. This avoids stack allocation for ranked memrefs (and thus stack exhaustion and potential concurrent-access problems) and simplifies device function invocation on GPUs.

Provide an option in the standard-to-LLVM conversion to generate auxiliary wrapper functions with the same interface as the previous calling convention, compatible with LLVM IR produced from C sources. These auxiliary functions pack the individual values into a descriptor structure or unpack it. They also handle descriptor stack allocation if necessary, serving as an allocation scope: the memory reserved by `alloca` will be freed on exiting the auxiliary function.

The effect of this change on LLVM IR generated purely from MLIR is minimal. When interfacing MLIR-generated LLVM IR with C-generated LLVM IR, the integration only needs to request the auxiliary functions and call the wrapper function instead of the original function. This also opens the door to forwarding aliasing and alignment information from memrefs to LLVM IR pointers in the standard-to-LLVM conversion.
parent 1dc62d0358 · commit 5a1778057f
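Before the diff itself, a minimal sketch of how a lowering pipeline would opt into the wrapper emission this commit adds. This is illustrative only: `pm` is an assumed `mlir::PassManager`; `createLowerToLLVMPass` and its two flags are the API as changed in this commit.

```cpp
// Sketch (not part of the commit): request C-compatible wrapper emission
// when lowering Standard to the LLVM dialect, using the two-flag factory
// signature introduced by this commit.
pm.addPass(createLowerToLLVMPass(/*useAlloca=*/false,
                                 /*emitCWrappers=*/true));
```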
@@ -248,58 +248,123 @@ func @bar() {

 ### Calling Convention for `memref`

-For function _arguments_ of `memref` type, ranked or unranked, the type of the
-argument is a _pointer_ to the memref descriptor type defined above. The caller
-of such function is required to store the descriptor in memory and guarantee
-that the storage remains live until the callee returns. The caller can then
-pass the pointer to that memory as function argument. The callee loads from
-the pointers it was passed as arguments in the entry block of the function,
-making the descriptor passed in as argument available for use similarly to
-locally-defined descriptors.
+Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a
+list of arguments of non-aggregate types that the memref descriptor defined
+above comprises. That is, the outer struct type and the inner array types are
+replaced with individual arguments.

 This convention is implemented in the conversion of `std.func` and `std.call` to
-the LLVM dialect. Conversions from other dialects should take it into account.
-The motivation for this convention is to simplify the ABI for interfacing with
-other LLVM modules, in particular those generated from C sources, while avoiding
-platform-specific aspects until MLIR has a proper ABI modeling.
+the LLVM dialect, with the former unpacking the descriptor into a set of
+individual values and the latter packing those values back into a descriptor so
+as to make it transparently usable by other operations. Conversions from other
+dialects should take this convention into account.
+
+This specific convention is motivated by the necessity to specify alignment and
+aliasing attributes on the raw pointers underpinning the memref.

-Example:
+Examples:

 ```mlir
-func @foo(memref<?xf32>) -> () {
-  %c0 = constant 0 : index
-  load %arg0[%c0] : memref<?xf32>
+func @foo(%arg0: memref<?xf32>) -> () {
+  "use"(%arg0) : (memref<?xf32>) -> ()
   return
 }
-func @bar(%arg0: index) {
-  %0 = alloc(%arg0) : memref<?xf32>
-  call @foo(%0) : (memref<?xf32>)-> ()
-  return
-}

-// Gets converted to the following IR.
-// Accepts a pointer to the memref descriptor.
-llvm.func @foo(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) {
-  // Loads the descriptor so that it can be used similarly to locally
-  // created descriptors.
-  %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm<"float*">, // Allocated pointer.
+               %arg1: !llvm<"float*">, // Aligned pointer.
+               %arg2: !llvm.i64,       // Offset.
+               %arg3: !llvm.i64,       // Size in dim 0.
+               %arg4: !llvm.i64) {     // Stride in dim 0.
+  // Populate memref descriptor structure.
+  %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Descriptor is now usable as a single value.
+  "use"(%5) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">) -> ()
+  llvm.return
+}
+```
+
+```mlir
+func @bar() {
+  %0 = "get"() : () -> (memref<?xf32>)
+  call @foo(%0) : (memref<?xf32>) -> ()
+  return
+}
+
+// Gets converted to the following.
+
+llvm.func @bar() {
+  %0 = "get"() : () -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Unpack the memref descriptor.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %5 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+  // Pass individual values to the callee.
+  llvm.call @foo(%1, %2, %3, %4, %5) : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
+  llvm.return
+}

-llvm.func @bar(%arg0: !llvm.i64) {
-  // ... Allocation ...
-  // Definition of the descriptor.
-  %7 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-  // ... Filling in the descriptor ...
-  %14 = // The final value of the allocated descriptor.
-  // Allocate the memory for the descriptor and store it.
-  %15 = llvm.mlir.constant(1 : index) : !llvm.i64
-  %16 = llvm.alloca %15 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-    : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
-  llvm.store %14, %16 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">
-  // Pass the pointer to the function.
-  llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
-}
 ```

+For **unranked** memrefs, the list of function arguments always contains two
+elements, same as the unranked memref descriptor: an integer rank, and a
+type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that
+while the _calling convention_ does not require stack allocation, _casting_ to
+unranked memref does since one cannot take an address of an SSA value containing
+the ranked memref. The caller is in charge of ensuring thread safety and
+eventually removing unnecessary stack allocations in cast operations.
+
+Example:
+
+```mlir
+llvm.func @foo(%arg0: memref<*xf32>) -> () {
+  "use"(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+
+// Gets converted to the following.
+
+llvm.func @foo(%arg0: !llvm.i64,      // Rank.
+               %arg1: !llvm<"i8*">) { // Type-erased pointer to descriptor.
+  // Pack the unranked memref descriptor.
+  %0 = llvm.mlir.undef : !llvm<"{ i64, i8* }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i8* }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i8* }">
+
+  "use"(%2) : (!llvm<"{ i64, i8* }">) -> ()
+  llvm.return
+}
+```
+
+```mlir
+llvm.func @bar() {
+  %0 = "get"() : () -> (memref<*xf32>)
+  call @foo(%0) : (memref<*xf32>) -> ()
+  return
+}
+
+// Gets converted to the following.
+
+llvm.func @bar() {
+  %0 = "get"() : () -> (!llvm<"{ i64, i8* }">)
+
+  // Unpack the memref descriptor.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ i64, i8* }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ i64, i8* }">
+
+  // Pass individual values to the callee.
+  llvm.call @foo(%1, %2) : (!llvm.i64, !llvm<"i8*">) -> ()
+  llvm.return
+}
+```
@@ -307,6 +372,141 @@ llvm.func @bar(%arg0: !llvm.i64) {
 *This convention may or may not apply if the conversion of MemRef types is
 overridden by the user.*

+### C-compatible wrapper emission
+
+In practical cases, it may be desirable to have externally-facing functions
+with a single attribute corresponding to a MemRef argument. When interfacing
+with LLVM IR produced from C, the code needs to respect the corresponding
+calling convention. The conversion to the LLVM dialect provides an option to
+generate wrapper functions that take memref descriptors as pointers-to-struct
+compatible with data types produced by Clang when compiling C sources.
+
+More specifically, a memref argument is converted into a pointer-to-struct
+argument of type `{T*, T*, i64, i64[N], i64[N]}*` in the wrapper function, where
+`T` is the converted element type and `N` is the memref rank. This type is
+compatible with that produced by Clang for the following C++ structure template
+instantiations or their equivalents in C.
+
+```cpp
+template<typename T, size_t N>
+struct MemRefDescriptor {
+  T *allocated;
+  T *aligned;
+  intptr_t offset;
+  intptr_t sizes[N];
+  intptr_t strides[N];
+};
+```
+
+If enabled, the option will do the following. For _external_ functions declared
+in the MLIR module.
+
+1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
+   are converted to pointer-to-struct and the remaining arguments are converted
+   as usual.
+1. Add a body to the original function (making it non-external) that
+   1. allocates a memref descriptor,
+   1. populates it,
+   1. passes the pointer to it into the newly declared interface function, and
+   1. collects the result of the call and returns it to the caller.
+
+For (non-external) functions defined in the MLIR module.
+
+1. Define a new function `_mlir_ciface_<original name>` where memref arguments
+   are converted to pointer-to-struct and the remaining arguments are converted
+   as usual.
+1. Populate the body of the newly defined function with IR that
+   1. loads descriptors from pointers;
+   1. unpacks the descriptors into individual non-aggregate values;
+   1. passes these values into the original function;
+   1. collects the result of the call and returns it to the caller.
+
+Examples:
+
+```mlir
+func @qux(%arg0: memref<?x?xf32>)
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @qux(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+               %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+               %arg6: !llvm.i64) {
+  // Populate memref descriptor (as per calling convention).
+  %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+  // Store the descriptor in a stack-allocated space.
+  %8 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %9 = llvm.alloca %8 x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+    : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  llvm.store %7, %9 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+  // Call the interface function.
+  llvm.call @_mlir_ciface_qux(%9) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
+
+  // The stored descriptor will be freed on return.
+  llvm.return
+}
+
+// Interface function.
+llvm.func @_mlir_ciface_qux(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+```
+
+```mlir
+func @foo(%arg0: memref<?x?xf32>) {
+  return
+}
+
+// Gets converted into the following.
+
+// Function with unpacked arguments.
+llvm.func @foo(%arg0: !llvm<"float*">, %arg1: !llvm<"float*">, %arg2: !llvm.i64,
+               %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64,
+               %arg6: !llvm.i64) {
+  llvm.return
+}
+
+// Interface function callable from C.
+llvm.func @_mlir_ciface_foo(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+  // Load the descriptor.
+  %0 = llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+
+  // Unpack the descriptor as per calling convention.
+  %1 = llvm.extractvalue %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %2 = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %3 = llvm.extractvalue %0[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %4 = llvm.extractvalue %0[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %5 = llvm.extractvalue %0[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %6 = llvm.extractvalue %0[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %7 = llvm.extractvalue %0[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
+    : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64,
+       !llvm.i64, !llvm.i64) -> ()
+  llvm.return
+}
+```
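For concreteness, here is what the C++ side of the first example above could look like. This is an illustrative sketch, not part of the commit: it assumes the interface function `@_mlir_ciface_qux` was emitted as shown, and it re-declares the descriptor struct from the documentation above.

```cpp
// Sketch: a C++ definition that the MLIR-generated @qux above reaches
// through its interface function. The struct layout must match the
// MemRefDescriptor<T, N> template shown in the documentation.
#include <cstddef>
#include <cstdint>

template <typename T, size_t N>
struct MemRefDescriptor {
  T *allocated;
  T *aligned;
  intptr_t offset;
  intptr_t sizes[N];
  intptr_t strides[N];
};

// MLIR passes a pointer to a stack-allocated descriptor; the storage is only
// guaranteed to live until this call returns, so do not retain the pointer.
extern "C" void _mlir_ciface_qux(MemRefDescriptor<float, 2> *desc) {
  float *base = desc->aligned + desc->offset;
  for (intptr_t i = 0; i < desc->sizes[0]; ++i)
    for (intptr_t j = 0; j < desc->sizes[1]; ++j)
      base[i * desc->strides[0] + j * desc->strides[1]] += 1.0f;
}
```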
+Rationale: Introducing auxiliary functions for C-compatible interfaces is
+preferred to modifying the calling convention since it will minimize the effect
+of C compatibility on intra-module calls or calls between MLIR-generated
+functions. In particular, when calling external functions from an MLIR module
+in a (parallel) loop, the fact of storing a memref descriptor on stack can lead
+to stack exhaustion and/or concurrent access to the same address. Auxiliary
+interface functions serve as an allocation scope in this case. Furthermore,
+when targeting accelerators with separate memory spaces such as GPUs,
+stack-allocated descriptors passed by pointer would have to be transferred to
+the device memory, which introduces significant overhead. In such situations,
+auxiliary interface functions are executed on the host and only pass the values
+through the device function invocation mechanism.

 ## Repeated Successor Removal

 Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect
@@ -36,8 +36,8 @@ class LLVMType;

 /// Set of callbacks that allows the customization of LLVMTypeConverter.
 struct LLVMTypeConverterCustomization {
-  using CustomCallback =
-      std::function<LLVM::LLVMType(LLVMTypeConverter &, Type)>;
+  using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
+                                                     SmallVectorImpl<Type> &)>;

   /// Customize the type conversion of function arguments.
   CustomCallback funcArgConverter;
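As an illustrative sketch (not part of this diff), installing a custom converter with the new callback signature could look like this; the lambda body simply delegates to the default `structFuncArgTypeConverter` declared below.

```cpp
// Sketch: a custom funcArgConverter using the new LogicalResult-based
// signature. Instead of returning a single LLVM type, the callback appends
// one or more converted types to `result`.
LLVMTypeConverterCustomization customs;
customs.funcArgConverter = [](LLVMTypeConverter &converter, Type type,
                              SmallVectorImpl<Type> &result) -> LogicalResult {
  // A real customization could special-case memref types here.
  return structFuncArgTypeConverter(converter, type, result);
};
```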
@@ -47,19 +47,26 @@ struct LLVMTypeConverterCustomization {
 };

 /// Callback to convert function argument types. It converts a MemRef function
-/// argument to a struct that contains the descriptor information. Converted
-/// types are promoted to a pointer to the converted type.
-LLVM::LLVMType structFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                          Type type);
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                         Type type,
+                                         SmallVectorImpl<Type> &result);

 /// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                           Type type);
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                          Type type,
+                                          SmallVectorImpl<Type> &result);

 /// Conversion from types in the Standard dialect to the LLVM IR dialect.
 class LLVMTypeConverter : public TypeConverter {
+  /// Give structFuncArgTypeConverter access to memref-specific functions.
+  friend LogicalResult
+  structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
+                             SmallVectorImpl<Type> &result);
+
 public:
   using TypeConverter::convertType;
@@ -107,6 +114,15 @@ public:
   Value promoteOneMemRefDescriptor(Location loc, Value operand,
                                    OpBuilder &builder);

+  /// Converts the function type to a C-compatible format, in particular using
+  /// pointers to memref descriptors for arguments.
+  LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
+
+  /// Creates descriptor structs from individual values constituting them.
+  Operation *materializeConversion(PatternRewriter &rewriter, Type type,
+                                   ArrayRef<Value> values,
+                                   Location loc) override;
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
@@ -133,14 +149,34 @@ private:
   // by LLVM.
   Type convertFloatType(FloatType type);

-  // Convert a memref type into an LLVM type that captures the relevant data.
-  // For statically-shaped memrefs, the resulting type is a pointer to the
-  // (converted) memref element type. For dynamically-shaped memrefs, the
-  // resulting type is an LLVM structure type that contains:
-  //   1. a pointer to the (converted) memref element type
-  //   2. as many index types as memref has dynamic dimensions.
+  /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type);

+  /// Convert a memref type into a list of non-aggregate LLVM IR types that
+  /// contain all the relevant data. In particular, the list will contain:
+  /// - two pointers to the memref element type, followed by
+  /// - an integer offset, followed by
+  /// - one integer size per dimension of the memref, followed by
+  /// - one integer stride per dimension of the memref.
+  /// For example, memref<?x?xf32> is converted to the following list:
+  /// - `!llvm<"float*">` (allocated pointer),
+  /// - `!llvm<"float*">` (aligned pointer),
+  /// - `!llvm.i64` (offset),
+  /// - `!llvm.i64`, `!llvm.i64` (sizes),
+  /// - `!llvm.i64`, `!llvm.i64` (strides).
+  /// These types can be recomposed to a memref descriptor struct.
+  SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
+
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR
+  /// types that contain all the relevant data. In particular, this list
+  /// contains:
+  /// - an integer rank, followed by
+  /// - a pointer to the memref descriptor struct.
+  /// For example, memref<*xf32> is converted to the following list:
+  /// - `!llvm.i64` (rank),
+  /// - `!llvm<"i8*">` (type-erased pointer).
+  /// These types can be recomposed to an unranked memref descriptor struct.
+  SmallVector<Type, 2> convertUnrankedMemRefSignature();
+
   // Convert an unranked memref type to an LLVM type that captures the
   // runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type);
@@ -180,6 +216,7 @@ protected:
   /// Builds IR to set a value in the struct at position pos
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };
+
 /// Helper class to produce LLVM dialect operations extracting or inserting
 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
 /// The Value may be null, in which case none of the operations are valid.
@@ -234,11 +271,63 @@ public:
   /// Returns the (LLVM) type this descriptor points to.
   LLVM::LLVMType getElementType();

+  /// Builds IR populating a MemRef descriptor structure from a list of
+  /// individual values composing that descriptor, in the following order:
+  /// - allocated pointer;
+  /// - aligned pointer;
+  /// - offset;
+  /// - <rank> sizes;
+  /// - <rank> strides;
+  /// where <rank> is the MemRef rank as provided in `type`.
+  static Value pack(OpBuilder &builder, Location loc,
+                    LLVMTypeConverter &converter, MemRefType type,
+                    ValueRange values);
+
+  /// Builds IR extracting individual elements of a MemRef descriptor structure
+  /// and returning them as `results` list.
+  static void unpack(OpBuilder &builder, Location loc, Value packed,
+                     MemRefType type, SmallVectorImpl<Value> &results);
+
+  /// Returns the number of non-aggregate values that would be produced by
+  /// `unpack`.
+  static unsigned getNumUnpackedValues(MemRefType type);
+
 private:
   // Cached index type.
   Type indexType;
 };
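A sketch of how a conversion pattern might use this pack/unpack API (illustrative only; `builder`, `loc`, `converter`, `memRefType`, and `descriptor` are assumed to be in scope):

```cpp
// Sketch: round-tripping a memref descriptor through unpack/pack inside a
// conversion pattern. The unpacked values are exactly what the new calling
// convention passes across function boundaries.
SmallVector<Value, 8> components;
MemRefDescriptor::unpack(builder, loc, descriptor, memRefType, components);
assert(components.size() ==
       MemRefDescriptor::getNumUnpackedValues(memRefType));

// ... forward `components` as individual call operands, then rebuild a
// struct-typed descriptor on the callee side:
Value rebuilt =
    MemRefDescriptor::pack(builder, loc, converter, memRefType, components);
```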
+
+/// Helper class allowing the user to access a range of Values that correspond
+/// to an unpacked memref descriptor using named accessors. This does not own
+/// the values.
+class MemRefDescriptorView {
+public:
+  /// Constructs the view from a range of values. Infers the rank from the size
+  /// of the range.
+  explicit MemRefDescriptorView(ValueRange range);
+
+  /// Returns the allocated pointer Value.
+  Value allocatedPtr();
+
+  /// Returns the aligned pointer Value.
+  Value alignedPtr();
+
+  /// Returns the offset Value.
+  Value offset();
+
+  /// Returns the pos-th size Value.
+  Value size(unsigned pos);
+
+  /// Returns the pos-th stride Value.
+  Value stride(unsigned pos);
+
+private:
+  /// Rank of the memref the descriptor is pointing to.
+  int rank;
+  /// Underlying range of Values.
+  ValueRange elements;
+};
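The constructor infers the rank as `(range.size() - 3) / 2`, since two pointers and an offset precede the sizes and strides (see the implementation later in this diff). An illustrative sketch of its use, assuming `operands` holds the unpacked values of a `memref<?x?xf32>`:

```cpp
// Sketch: named access to an unpacked rank-2 descriptor without rebuilding
// the struct. `operands` is assumed to hold the seven unpacked values:
// allocated ptr, aligned ptr, offset, size0, size1, stride0, stride1.
MemRefDescriptorView view(operands);
Value aligned = view.alignedPtr(); // operands[1]
Value size0 = view.size(0);        // operands[3]
Value stride1 = view.stride(1);    // operands[3 + rank + 1] == operands[6]
```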
+
 class UnrankedMemRefDescriptor : public StructBuilder {
 public:
   /// Construct a helper for the given descriptor value.
@@ -255,6 +344,23 @@ public:
   Value memRefDescPtr(OpBuilder &builder, Location loc);
   /// Builds IR setting ranked memref descriptor ptr
   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
+
+  /// Builds IR populating an unranked MemRef descriptor structure from a list
+  /// of individual constituent values in the following order:
+  /// - rank of the memref;
+  /// - pointer to the memref descriptor.
+  static Value pack(OpBuilder &builder, Location loc,
+                    LLVMTypeConverter &converter, UnrankedMemRefType type,
+                    ValueRange values);
+
+  /// Builds IR extracting individual elements that compose an unranked memref
+  /// descriptor and returns them as `results` list.
+  static void unpack(OpBuilder &builder, Location loc, Value packed,
+                     SmallVectorImpl<Value> &results);
+
+  /// Returns the number of non-aggregate values that would be produced by
+  /// `unpack`.
+  static unsigned getNumUnpackedValues() { return 2; }
 };
+
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
 /// conversion patterns with an access to the containing LLVMLowering for the
@@ -29,16 +29,21 @@ void populateStdToLLVMMemoryConversionPatters(
 void populateStdToLLVMNonMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

-/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
+/// `emitCWrappers` is set, the pattern will also produce functions
+/// that pass memref descriptors by pointer-to-structure in addition to the
+/// default unpacked form.
 void populateStdToLLVMDefaultFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool emitCWrappers = false);

 /// Collect a set of default patterns to convert from the Standard dialect to
 /// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will
 /// generate `llvm.alloca` instead of calls to "malloc".
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool useAlloca = false);
+                                         bool useAlloca = false,
+                                         bool emitCWrappers = false);

 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
@@ -53,7 +58,7 @@ void populateStdToLLVMBarePtrConversionPatterns(
 /// Specifying `useAlloca=true` emits stack allocations instead. In the future
 /// this may become an enum when we have concrete uses for other options.
 std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(bool useAlloca = false);
+createLowerToLLVMPass(bool useAlloca = false, bool emitCWrappers = false);

 namespace LLVM {
 /// Make argument-taking successors of each block distinct. PHI nodes in LLVM
@@ -30,6 +30,13 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
   return ("arg" + Twine(arg)).toStringRef(out);
 }

+/// Returns true if the given name is a valid argument attribute name.
+inline bool isArgAttrName(StringRef name) {
+  APInt unused;
+  return name.startswith("arg") &&
+         !name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
+}
+
 /// Return the name of the attribute used for function results.
 inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
   out.clear();
@@ -113,6 +113,8 @@ private:
   }

   void declareCudaFunctions(Location loc);
+  void addParamToList(OpBuilder &builder, Location loc, Value param, Value list,
+                      unsigned pos, Value one);
   Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
   Value generateKernelNameConstant(StringRef name, Location loc,
                                    OpBuilder &builder);
@@ -231,6 +233,35 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
   }
 }

+/// Emits the IR with the following structure:
+///
+///   %data = llvm.alloca 1 x type-of(<param>)
+///   llvm.store <param>, %data
+///   %typeErased = llvm.bitcast %data to !llvm<"i8*">
+///   %addr = llvm.getelementptr <list>[<pos>]
+///   llvm.store %typeErased, %addr
+///
+/// This is necessary to construct the list of arguments passed to the kernel
+/// function as accepted by cuLaunchKernel, i.e. as a void** that points to a
+/// list of stack-allocated type-erased pointers to the actual arguments.
+void GpuLaunchFuncToCudaCallsPass::addParamToList(OpBuilder &builder,
+                                                  Location loc, Value param,
+                                                  Value list, unsigned pos,
+                                                  Value one) {
+  auto memLocation = builder.create<LLVM::AllocaOp>(
+      loc, param.getType().cast<LLVM::LLVMType>().getPointerTo(), one,
+      /*alignment=*/1);
+  builder.create<LLVM::StoreOp>(loc, param, memLocation);
+  auto casted =
+      builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+
+  auto index = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+                                                builder.getI32IntegerAttr(pos));
+  auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), list,
+                                         ArrayRef<Value>{index});
+  builder.create<LLVM::StoreOp>(loc, casted, gep);
+}
+
 // Generates a parameters array to be used with a CUDA kernel launch call. The
 // arguments are extracted from the launchOp.
 // The generated code is essentially as follows:
@@ -241,53 +272,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
 //   return %array
 Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
                                                      OpBuilder &builder) {
+
+  // Get the launch target.
+  auto containingModule = launchOp.getParentOfType<ModuleOp>();
+  if (!containingModule)
+    return {};
+  auto gpuModule = containingModule.lookupSymbol<gpu::GPUModuleOp>(
+      launchOp.getKernelModuleName());
+  if (!gpuModule)
+    return {};
+  auto gpuFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>(launchOp.kernel());
+  if (!gpuFunc)
+    return {};
+
+  unsigned numArgs = gpuFunc.getNumArguments();
+
   auto numKernelOperands = launchOp.getNumKernelOperands();
   Location loc = launchOp.getLoc();
   auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
                                               builder.getI32IntegerAttr(1));
-  // Provision twice as much for the `array` to allow up to one level of
-  // indirection for each argument.
   auto arraySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
+      loc, getInt32Type(), builder.getI32IntegerAttr(numArgs));
   auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
                                               arraySize, /*alignment=*/0);
+
+  unsigned pos = 0;
   for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
     auto operand = launchOp.getKernelOperand(idx);
     auto llvmType = operand.getType().cast<LLVM::LLVMType>();
-    Value memLocation = builder.create<LLVM::AllocaOp>(
-        loc, llvmType.getPointerTo(), one, /*alignment=*/1);
-    builder.create<LLVM::StoreOp>(loc, operand, memLocation);
-    auto casted =
-        builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);

     // Assume all struct arguments come from MemRef. If this assumption does
     // not hold anymore then we need `launchOp` to lower from MemRefType and
     // not after LLVMConversion has taken place and the MemRef information is
     // lost.
-    // Extra level of indirection in the `array`:
-    // the descriptor pointer is registered via @mcuMemHostRegisterPtr
-    if (llvmType.isStructTy()) {
-      auto registerFunc =
-          getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
-      auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
-      auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
-                                             ArrayRef<Value>{nullPtr, one});
-      auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
-      builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
-                                   builder.getSymbolRefAttr(registerFunc),
-                                   ArrayRef<Value>{casted, size});
-      Value memLocation = builder.create<LLVM::AllocaOp>(
-          loc, getPointerPointerType(), one, /*alignment=*/1);
-      builder.create<LLVM::StoreOp>(loc, casted, memLocation);
-      casted =
-          builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+    if (!llvmType.isStructTy()) {
+      addParamToList(builder, loc, operand, array, pos++, one);
+      continue;
     }

-    auto index = builder.create<LLVM::ConstantOp>(
-        loc, getInt32Type(), builder.getI32IntegerAttr(idx));
-    auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
-                                           ArrayRef<Value>{index});
-    builder.create<LLVM::StoreOp>(loc, casted, gep);
+    // Put individual components of a memref descriptor into the flat argument
+    // list. We cannot use unpackMemref from LLVM lowering here because we have
+    // no access to MemRefType that had been lowered away.
+    for (int32_t j = 0, ej = llvmType.getStructNumElements(); j < ej; ++j) {
+      auto elemType = llvmType.getStructElementType(j);
+      if (elemType.isArrayTy()) {
+        for (int32_t k = 0, ek = elemType.getArrayNumElements(); k < ek; ++k) {
+          Value elem = builder.create<LLVM::ExtractValueOp>(
+              loc, elemType.getArrayElementType(), operand,
+              builder.getI32ArrayAttr({j, k}));
+          addParamToList(builder, loc, elem, array, pos++, one);
+        }
+      } else {
+        assert((elemType.isIntegerTy() || elemType.isFloatTy() ||
+                elemType.isDoubleTy() || elemType.isPointerTy()) &&
+               "expected scalar type");
+        Value strct = builder.create<LLVM::ExtractValueOp>(
+            loc, elemType, operand, builder.getI32ArrayAttr(j));
+        addParamToList(builder, loc, strct, array, pos++, one);
+      }
+    }
   }

   return array;
 }
@@ -392,6 +436,10 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   auto cuFunctionRef =
       builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
+  if (!paramsArray) {
+    launchOp.emitOpError() << "cannot pass given parameters to the kernel";
+    return signalPassFailure();
+  }
   auto nullpointer =
       builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
   builder.create<LLVM::CallOp>(
@@ -564,8 +564,8 @@ struct GPUFuncOpLowering : LLVMOpLowering {
     // Remap proper input types.
     TypeConverter::SignatureConversion signatureConversion(
        gpuFuncOp.front().getNumArguments());
-    for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
-      signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
+    lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false,
+                                      signatureConversion);

     // Create the new function operation. Only copy those attributes that are
     // not specific to function modeling.
@@ -651,25 +651,6 @@ struct GPUFuncOpLowering : LLVMOpLowering {
     rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
                                       signatureConversion);

-    {
-      // For memref-typed arguments, insert the relevant loads in the beginning
-      // of the block to comply with the LLVM dialect calling convention. This
-      // needs to be done after signature conversion to get the right types.
-      OpBuilder::InsertionGuard guard(rewriter);
-      Block &block = llvmFuncOp.front();
-      rewriter.setInsertionPointToStart(&block);
-
-      for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
-        if (!en.value().isa<MemRefType>() &&
-            !en.value().isa<UnrankedMemRefType>())
-          continue;
-
-        BlockArgument arg = block.getArgument(en.index());
-        Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
-      }
-    }
-
     rewriter.eraseOp(gpuFuncOp);
     return matchSuccess();
   }
@@ -577,7 +577,8 @@ void ConvertLinalgToLLVMPass::runOnModule() {
   LinalgTypeConverter converter(&getContext());
   populateAffineToStdConversionPatterns(patterns, &getContext());
   populateLoopToStdConversionPatterns(patterns, &getContext());
-  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
+                                      /*emitCWrappers=*/true);
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateLinalgToStandardConversionPatterns(patterns, &getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
@@ -30,6 +30,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"

 using namespace mlir;
@@ -43,6 +44,11 @@ static llvm::cl::opt<bool>
                    llvm::cl::desc("Replace emission of malloc/free by alloca"),
                    llvm::cl::init(false));

+static llvm::cl::opt<bool>
+    clEmitCWrappers(PASS_NAME "-emit-c-wrappers",
+                    llvm::cl::desc("Emit C-compatible wrapper functions"),
+                    llvm::cl::init(false));
+
 static llvm::cl::opt<bool> clUseBarePtrCallConv(
     PASS_NAME "-use-bare-ptr-memref-call-conv",
     llvm::cl::desc("Replace FuncOp's MemRef arguments with "
@@ -66,18 +72,32 @@ LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() {
   funcArgConverter = structFuncArgTypeConverter;
 }

-// Callback to convert function argument types. It converts a MemRef function
-// arguments to a struct that contains the descriptor information. Converted
-// types are promoted to a pointer to the converted type.
-LLVM::LLVMType mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                                Type type) {
-  auto converted =
-      converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
-  if (!converted)
-    return {};
-  if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
-    converted = converted.getPointerTo();
-  return converted;
+/// Callback to convert function argument types. It converts a MemRef function
+/// argument to a list of non-aggregate types containing descriptor
+/// information, and an UnrankedMemRef function argument to a list containing
+/// the rank and a pointer to a descriptor struct.
+LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                               Type type,
+                                               SmallVectorImpl<Type> &result) {
+  if (auto memref = type.dyn_cast<MemRefType>()) {
+    auto converted = converter.convertMemRefSignature(memref);
+    if (converted.empty())
+      return failure();
+    result.append(converted.begin(), converted.end());
+    return success();
+  }
+  if (type.isa<UnrankedMemRefType>()) {
+    auto converted = converter.convertUnrankedMemRefSignature();
+    if (converted.empty())
+      return failure();
+    result.append(converted.begin(), converted.end());
+    return success();
+  }
+  auto converted = converter.convertType(type);
+  if (!converted)
+    return failure();
+  result.push_back(converted);
+  return success();
 }

 /// Convert a MemRef type to a bare pointer to the MemRef element type.
@@ -96,15 +116,26 @@ static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
 }

 /// Callback to convert function argument types. It converts MemRef function
-/// arguments to bare pointers to the MemRef element type. Converted types are
-/// not promoted to pointers.
-LLVM::LLVMType mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
-                                                 Type type) {
+/// arguments to bare pointers to the MemRef element type.
+LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                                Type type,
+                                                SmallVectorImpl<Type> &result) {
   // TODO: Add support for unranked memref.
-  if (auto memrefTy = type.dyn_cast<MemRefType>())
-    return convertMemRefTypeToBarePtr(converter, memrefTy)
-        .dyn_cast_or_null<LLVM::LLVMType>();
-  return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+  if (auto memrefTy = type.dyn_cast<MemRefType>()) {
+    auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy);
+    if (!llvmTy)
+      return failure();
+
+    result.push_back(llvmTy);
+    return success();
+  }
+
+  auto llvmTy = converter.convertType(type);
+  if (!llvmTy)
+    return failure();
+
+  result.push_back(llvmTy);
+  return success();
 }

 /// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
@@ -165,6 +196,33 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   return converted.getPointerTo();
 }

+/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
+/// values.
+SmallVector<Type, 5>
+LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
+  SmallVector<Type, 5> results;
+  assert(isStrided(type) &&
+         "Non-strided layout maps must have been normalized away");
+
+  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  auto indexTy = getIndexType();
+
+  results.insert(results.begin(), 2,
+                 elementType.getPointerTo(type.getMemorySpace()));
+  results.push_back(indexTy);
+  auto rank = type.getRank();
+  results.insert(results.end(), 2 * rank, indexTy);
+  return results;
+}
+
+/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
+/// pointer to descriptor".
+SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
+  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+}
+
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types. If MLIR Function has zero results, the LLVM
 // Function has one VoidType result. If MLIR Function has more than one result,
@@ -175,9 +233,8 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   // Convert argument types one by one and check for errors.
   for (auto &en : llvm::enumerate(type.getInputs())) {
     Type type = en.value();
-    auto converted = customizations.funcArgConverter(*this, type)
-                         .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!converted)
+    SmallVector<Type, 8> converted;
+    if (failed(customizations.funcArgConverter(*this, type, converted)))
       return {};
     result.addInputs(en.index(), converted);
   }
@@ -199,6 +256,47 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
 }

+/// Converts the function type to a C-compatible format, in particular using
+/// pointers to memref descriptors for arguments.
+LLVM::LLVMType
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+  SmallVector<LLVM::LLVMType, 4> inputs;
+
+  for (Type t : type.getInputs()) {
+    auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
+    if (!converted)
+      return {};
+    if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+      converted = converted.getPointerTo();
+    inputs.push_back(converted);
+  }
+
+  LLVM::LLVMType resultType =
+      type.getNumResults() == 0
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+
+  return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
+}
+
+/// Creates descriptor structs from individual values constituting them.
+Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
+                                                    Type type,
+                                                    ArrayRef<Value> values,
+                                                    Location loc) {
+  if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
+                                          unrankedMemRefType, values)
+        .getDefiningOp();
+
+  auto memRefType = type.dyn_cast<MemRefType>();
+  assert(memRefType && "1->N conversion is only supported for memrefs");
+  return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
+      .getDefiningOp();
+}
+
 // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
 // contains:
 //   1. the pointer to the data buffer, followed by
@@ -473,6 +571,85 @@ LLVM::LLVMType MemRefDescriptor::getElementType() {
                         kAlignedPtrPosInMemRefDescriptor);
 }

+/// Creates a MemRef descriptor structure from a list of individual values
+/// composing that descriptor, in the following order:
+/// - allocated pointer;
+/// - aligned pointer;
+/// - offset;
+/// - <rank> sizes;
+/// - <rank> strides;
+/// where <rank> is the MemRef rank as provided in `type`.
+Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &converter, MemRefType type,
+                             ValueRange values) {
+  Type llvmType = converter.convertType(type);
+  auto d = MemRefDescriptor::undef(builder, loc, llvmType);
+
+  d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
+  d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
+  d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
+
+  int64_t rank = type.getRank();
+  for (unsigned i = 0; i < rank; ++i) {
+    d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
+    d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
+  }
+
+  return d;
+}
+
+/// Builds IR extracting individual elements of a MemRef descriptor structure
+/// and returning them as `results` list.
+void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
+                              MemRefType type,
+                              SmallVectorImpl<Value> &results) {
+  int64_t rank = type.getRank();
+  results.reserve(results.size() + getNumUnpackedValues(type));
+
+  MemRefDescriptor d(packed);
+  results.push_back(d.allocatedPtr(builder, loc));
+  results.push_back(d.alignedPtr(builder, loc));
+  results.push_back(d.offset(builder, loc));
+  for (int64_t i = 0; i < rank; ++i)
+    results.push_back(d.size(builder, loc, i));
+  for (int64_t i = 0; i < rank; ++i)
+    results.push_back(d.stride(builder, loc, i));
+}
+
+/// Returns the number of non-aggregate values that would be produced by
+/// `unpack`.
+unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
+  // Two pointers, offset, <rank> sizes, <rank> strides.
+  return 3 + 2 * type.getRank();
+}
+
+/*============================================================================*/
+/* MemRefDescriptorView implementation.                                       */
+/*============================================================================*/
+
+MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
+    : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
+
+Value MemRefDescriptorView::allocatedPtr() {
+  return elements[kAllocatedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::alignedPtr() {
+  return elements[kAlignedPtrPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::offset() {
+  return elements[kOffsetPosInMemRefDescriptor];
+}
+
+Value MemRefDescriptorView::size(unsigned pos) {
+  return elements[kSizePosInMemRefDescriptor + pos];
+}
+
+Value MemRefDescriptorView::stride(unsigned pos) {
+  return elements[kSizePosInMemRefDescriptor + rank + pos];
+}
+
 /*============================================================================*/
 /* UnrankedMemRefDescriptor implementation                                    */
 /*============================================================================*/
@@ -504,6 +681,34 @@ void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
                                                 Location loc, Value v) {
   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
 }
+
+/// Builds IR populating an unranked MemRef descriptor structure from a list
+/// of individual constituent values in the following order:
+/// - rank of the memref;
+/// - pointer to the memref descriptor.
+Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
+                                     LLVMTypeConverter &converter,
+                                     UnrankedMemRefType type,
+                                     ValueRange values) {
+  Type llvmType = converter.convertType(type);
+  auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
+
+  d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
+  d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
+  return d;
+}
+
+/// Builds IR extracting individual elements that compose an unranked memref
+/// descriptor and returns them as `results` list.
+void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
+                                      Value packed,
+                                      SmallVectorImpl<Value> &results) {
+  UnrankedMemRefDescriptor d(packed);
+  results.reserve(results.size() + 2);
+  results.push_back(d.rank(builder, loc));
+  results.push_back(d.memRefDescPtr(builder, loc));
+}
+
 namespace {
 // Base class for Standard to LLVM IR op conversions. Matches the Op type
 // provided as template argument. Carries a reference to the LLVM dialect in
@ -551,9 +756,144 @@ protected:
|
|||
LLVM::LLVMDialect &dialect;
|
||||
};
|
||||
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
|
||||
attr.first.is(impl::getTypeAttrName()) ||
|
||||
attr.first.is("std.varargs") ||
|
||||
(filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
|
||||
continue;
|
||||
result.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
|
||||
/// arguments instead of unpacked arguments. This function can be called from C
|
||||
/// by passing a pointer to a C struct corresponding to a memref descriptor.
|
||||
/// Internally, the auxiliary function unpacks the descriptor into individual
|
||||
/// components and forwards them to `newFuncOp`.
|
||||
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
|
||||
auto type = funcOp.getType();
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
|
||||
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
|
||||
attributes);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
|
||||
|
||||
SmallVector<Value, 8> args;
|
||||
for (auto &en : llvm::enumerate(type.getInputs())) {
|
||||
Value arg = wrapperFuncOp.getArgument(en.index());
|
||||
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
|
||||
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
|
||||
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
|
||||
continue;
|
||||
}
|
||||
if (en.value().isa<UnrankedMemRefType>()) {
|
||||
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
|
||||
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
|
||||
continue;
|
||||
}
|
||||
|
||||
args.push_back(wrapperFuncOp.getArgument(en.index()));
|
||||
}
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
|
||||
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
|
||||
}

/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on the stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
/// the auxiliary function. This auxiliary external function is now compatible
/// with functions defined in C using pointers to C structs corresponding to a
/// memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  OpBuilder::InsertionGuard guard(builder);

  LLVM::LLVMType wrapperType =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);

  // Create the auxiliary function.
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, attributes);

  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Get a ValueRange containing the arguments. Note that ValueRange is
  // currently not constructible from a pair of iterators pointing to
  // BlockArgument.
  FunctionType type = funcOp.getType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  auto wrapperArgIters = newFuncOp.getArguments();
  SmallVector<Value, 8> wrapperArgs(wrapperArgIters.begin(),
                                    wrapperArgIters.end());
  ValueRange wrapperArgsRange(wrapperArgs);

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      // Store the packed descriptor on the stack and pass its address,
      // matching the C-compatible pointer-to-struct convention.
      auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
  builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
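As a rough guide to how many scalar arguments a packed memref expands into (a sketch derived from the descriptor layout used above; `getNumUnpackedValues` is the real source of truth):

#include <cstdint>

// A ranked descriptor holds two pointers, one offset, and one size plus one
// stride per dimension, so it unpacks into 3 + 2 * rank values. An unranked
// descriptor is just a rank and an opaque pointer: 2 values.
constexpr int64_t numUnpackedValuesRanked(int64_t rank) {
  return 3 + 2 * rank;
}
constexpr int64_t numUnpackedValuesUnranked() { return 2; }

static_assert(numUnpackedValuesRanked(2) == 7,
              "a 2-D memref expands into 7 arguments");

This matches the `CHECK-COUNT-7` pattern in the attribute-replication test below: a `memref<10x20xf32>` argument becomes seven primitive arguments.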

struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
protected:
-  using LLVMLegalizationPattern::LLVMLegalizationPattern;
+  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
  using UnsignedTypePair = std::pair<unsigned, Type>;

  // Gather the positions and types of memref-typed arguments in a given

@@ -579,14 +919,24 @@ protected:
    auto llvmType = lowering.convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);

-    // Only retain those attributes that are not constructed by build.
+    // Propagate argument attributes to all converted arguments obtained after
+    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
-    for (const auto &attr : funcOp.getAttrs()) {
-      if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
-          attr.first.is(impl::getTypeAttrName()) ||
-          attr.first.is("std.varargs"))
-        continue;
-      attributes.push_back(attr);
-    }
+    filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
+                         attributes);
+    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+      auto attr = impl::getArgAttrDict(funcOp, i);
+      if (!attr)
+        continue;
+
+      auto mapping = result.getInputMapping(i);
+      assert(mapping.hasValue() && "unexpected deletion of function argument");
+
+      // Replicate the original argument attribute on every new argument the
+      // memref was expanded into.
+      SmallString<8> name;
+      for (size_t j = mapping->inputNo; j < mapping->inputNo + mapping->size;
+           ++j) {
+        impl::getArgAttrName(j, name);
+        attributes.push_back(rewriter.getNamedAttr(name, attr));
+      }
+    }

    // Create an LLVM function, use external linkage by default until MLIR

@@ -607,34 +957,33 @@ protected:
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
-  using FuncOpConversionBase::FuncOpConversionBase;
+  FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
+                   bool emitCWrappers)
+      : FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = cast<FuncOp>(op);

-    // Store the positions of memref-typed arguments so that we can emit loads
-    // from them to follow the calling convention.
-    SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
-    getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);

    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);

-    // Insert loads from memref descriptor pointers in function bodies.
-    if (!newFuncOp.getBody().empty()) {
-      Block *firstBlock = &newFuncOp.getBody().front();
-      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
-      for (const auto &argInfo : promotedArgsInfo) {
-        BlockArgument arg = firstBlock->getArgument(argInfo.first);
-        Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
-        rewriter.replaceUsesOfBlockArgument(arg, loaded);
-      }
+    // If requested, emit the C-compatible adaptor: external functions get a
+    // descriptor-packing wrapper, functions with bodies get an unpacking
+    // wrapper callable by external C code.
+    if (emitWrappers) {
+      if (newFuncOp.isExternal())
+        wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp,
+                             newFuncOp);
+      else
+        wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp,
+                               newFuncOp);
    }

    rewriter.eraseOp(op);
    return matchSuccess();
  }

+private:
+  /// If true, also create the adaptor functions having signatures compatible
+  /// with those produced by clang.
+  const bool emitWrappers;
};
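The wrapper names are derived purely by prefixing; a sketch of the derivation integrators need when redirecting C call sites to the wrapper (only the `_mlir_ciface_` prefix is fixed by this patch, the helper itself is illustrative):

#include <string>

// Sketch: an MLIR function @qux gets a companion wrapper _mlir_ciface_qux.
std::string cWrapperName(const std::string &funcName) {
  return "_mlir_ciface_" + funcName;
}
// e.g. cWrapperName("qux") == "_mlir_ciface_qux"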

/// FuncOp legalization pattern that converts MemRef arguments to bare pointers

@@ -2273,14 +2622,17 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
}

void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool emitCWrappers) {
+  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
+                                    emitCWrappers);
}

void mlir::populateStdToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool useAlloca) {
-  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
+    bool useAlloca, bool emitCWrappers) {
+  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
+                                                  emitCWrappers);
  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
  populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
}
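A sketch of how a downstream conversion would assemble the full pattern set with wrapper emission enabled (the conversion-driver plumbing around it is assumed):

// Sketch: building the std-to-LLVM pattern set with C wrappers enabled.
void buildStdToLLVMPatterns(LLVMTypeConverter &typeConverter) {
  OwningRewritePatternList patterns;
  populateStdToLLVMConversionPatterns(typeConverter, patterns,
                                      /*useAlloca=*/false,
                                      /*emitCWrappers=*/true);
  // ... hand `patterns` to applyPartialConversion / applyFullConversion.
}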

@@ -2346,13 +2698,20 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
  for (auto it : llvm::zip(opOperands, operands)) {
    auto operand = std::get<0>(it);
    auto llvmOperand = std::get<1>(it);
-    if (!operand.getType().isa<MemRefType>() &&
-        !operand.getType().isa<UnrankedMemRefType>()) {
-      promotedOperands.push_back(operand);
+
+    // Unranked memrefs expand into rank/pointer pairs, ranked memrefs into
+    // the individual descriptor values; everything else passes through.
+    if (operand.getType().isa<UnrankedMemRefType>()) {
+      UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+                                       promotedOperands);
      continue;
    }
-    promotedOperands.push_back(
-        promoteOneMemRefDescriptor(loc, llvmOperand, builder));
+    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+      MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+                               promotedOperands);
+      continue;
+    }
+
+    promotedOperands.push_back(operand);
  }
  return promotedOperands;
}
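The call-site effect is purely a fan-out of operands; an illustrative (and only illustrative) bit of arithmetic for the resulting `llvm.call` operand count, assuming all memref operands are ranked:

#include <cstdint>
#include <vector>

// How many primitive operands a call receives after promoteMemRefDescriptors
// expands every ranked memref operand. No alloca is involved anymore.
int64_t promotedOperandCount(const std::vector<int64_t> &memrefRanks,
                             int64_t numScalarOperands) {
  int64_t count = numScalarOperands;
  for (int64_t rank : memrefRanks)
    count += 3 + 2 * rank; // two pointers, offset, sizes, strides
  return count;
}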

@@ -2362,11 +2721,21 @@ namespace {
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
  /// Creates an LLVM lowering pass.
  explicit LLVMLoweringPass(bool useAlloca = false,
-                            bool useBarePtrCallConv = false)
-      : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv) {}
+                            bool useBarePtrCallConv = false,
+                            bool emitCWrappers = false)
+      : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv),
+        emitCWrappers(emitCWrappers) {}

  /// Run the dialect converter on the module.
  void runOnModule() override {
+    if (useBarePtrCallConv && emitCWrappers) {
+      getModule().emitError()
+          << "incompatible conversion options: bare-pointer calling convention "
+             "and C wrapper emission";
+      signalPassFailure();
+      return;
+    }
+
    ModuleOp m = getModule();
    LLVM::ensureDistinctSuccessors(m);

@@ -2380,7 +2749,8 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
                                                 useAlloca);
    else
-      populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca);
+      populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca,
+                                          emitCWrappers);

    ConversionTarget target(getContext());
    target.addLegalDialect<LLVM::LLVMDialect>();

@@ -2393,19 +2763,23 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {

  /// Convert memrefs to bare pointers in function signatures.
  bool useBarePtrCallConv;

+  /// Emit wrappers for C-compatible pointer-to-struct memref descriptors.
+  bool emitCWrappers;
};
} // end namespace

std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::createLowerToLLVMPass(bool useAlloca) {
-  return std::make_unique<LLVMLoweringPass>(useAlloca);
+mlir::createLowerToLLVMPass(bool useAlloca, bool emitCWrappers) {
+  return std::make_unique<LLVMLoweringPass>(useAlloca,
+                                            /*useBarePtrCallConv=*/false,
+                                            emitCWrappers);
}
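A sketch of using the updated entry point from a pass pipeline (the pass-manager setup around it is assumed):

// Sketch: enabling C wrapper emission when lowering to the LLVM dialect.
void buildLoweringPipeline(PassManager &pm) {
  pm.addPass(createLowerToLLVMPass(/*useAlloca=*/false,
                                   /*emitCWrappers=*/true));
}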

static PassRegistration<LLVMLoweringPass>
-    pass("convert-std-to-llvm",
+    pass(PASS_NAME,
         "Convert scalar and vector operations from the "
         "Standard to the LLVM dialect",
         [] {
           return std::make_unique<LLVMLoweringPass>(
-              clUseAlloca.getValue(), clUseBarePtrCallConv.getValue());
+              clUseAlloca.getValue(), clUseBarePtrCallConv.getValue(),
+              clEmitCWrappers.getValue());
         });

@@ -90,26 +90,27 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

+    // TODO(ntv,zinenko,herhut): if the kernel function has been converted to
+    // the LLVM dialect but the caller hasn't (which happens during separate
+    // compilation), do not check type correspondence as it would require the
+    // verifier to be aware of the LLVM type conversion.
+    if (kernelLLVMFunction)
+      return success();
+
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
-    unsigned expectedNumArguments = kernelLLVMFunction
-                                        ? kernelLLVMFunction.getNumArguments()
-                                        : kernelGPUFunction.getNumArguments();
+    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

-    // Due to the ordering of the current impl of lowering and LLVMLowering,
-    // type checks need to be temporarily disabled.
-    // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
-    // to encode target module" has landed.
-    // auto functionType = kernelFunc.getType();
-    // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
-    //   if (getKernelOperand(i).getType() != functionType.getInput(i)) {
-    //     return emitOpError("type of function argument ")
-    //            << i << " does not match";
-    //   }
-    // }
+    auto functionType = kernelGPUFunction.getType();
+    for (unsigned i = 0; i < expectedNumArguments; ++i) {
+      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
+        return launchOp.emitOpError("type of function argument ")
+               << i << " does not match";
+      }
+    }

    return success();
  });

@@ -401,8 +401,7 @@ Block *ArgConverter::applySignatureConversion(
    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
    Operation *cast = typeConverter->materializeConversion(
        rewriter, origArg.getType(), replArgs, loc);
-    assert(cast->getNumResults() == 1 &&
-           cast->getNumOperands() == replArgs.size());
+    assert(cast->getNumResults() == 1);
    mapping.map(origArg, cast->getResult(0));
    info.argInfo[i] =
        ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));

@@ -6,8 +6,8 @@ module attributes {gpu.container_module} {
  // CHECK: llvm.mlir.global internal constant @[[global:.*]]("CUBIN")

  gpu.module @kernel_module attributes {nvvm.cubin = "CUBIN"} {
-    gpu.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
-      gpu.return
+    llvm.func @kernel(%arg0: !llvm.float, %arg1: !llvm<"float*">) attributes {gpu.kernel} {
+      llvm.return
    }
  }

@@ -1,25 +1,11 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s

-// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> {dialect.a = true, dialect.b = 4 : i64}) {
-// CHECK-NEXT: llvm.load %arg0 : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-LABEL: func @check_attributes
+// When expanding the memref to multiple arguments, argument attributes are replicated.
+// CHECK-COUNT-7: {dialect.a = true, dialect.b = 4 : i64}
func @check_attributes(%static: memref<10x20xf32> {dialect.a = true, dialect.b = 4 : i64 }) {
  %c0 = constant 0 : index
  %0 = load %static[%c0, %c0]: memref<10x20xf32>
  return
}

-// CHECK-LABEL: func @external_func(!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
-// CHECK: func @call_external(%[[arg:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
-// CHECK: %[[ld:.*]] = llvm.load %[[arg]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: %[[alloca:.*]] = llvm.alloca %[[c1]] x !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: llvm.store %[[ld]], %[[alloca]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK: call @external_func(%[[alloca]]) : (!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> ()
-func @external_func(memref<10x20xf32>)
-
-func @call_external(%static: memref<10x20xf32>) {
-  call @external_func(%static) : (memref<10x20xf32>) -> ()
-  return
-}

@@ -1,14 +1,25 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s

-// CHECK-LABEL: func @check_strided_memref_arguments(
-// CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-LABEL: func @check_strided_memref_arguments(
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>,
                                     %dynamic : memref<?x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>,
                                     %mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) {
  return
}

-// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+// CHECK-LABEL: func @check_arguments
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
  return
}
@@ -16,7 +27,7 @@ func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %m
// CHECK-LABEL: func @mixed_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> {
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
-// CHECK-NEXT: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">

@@ -45,10 +56,9 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
  return %0 : memref<?x42x?xf32>
}

-// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) {
+// CHECK-LABEL: func @mixed_dealloc
func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
  dealloc %arg0 : memref<?x42x?xf32>

@@ -59,7 +69,7 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-LABEL: func @dynamic_alloc(
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
-// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
+// CHECK: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">

@@ -83,10 +93,9 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
  return %0 : memref<?x?xf32>
}

-// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @dynamic_dealloc
func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
  dealloc %arg0 : memref<?x?xf32>
@@ -94,10 +103,12 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
}

// CHECK-LABEL: func @mixed_load(
-// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK: %[[I:.*]]: !llvm.i64,
+// CHECK: %[[J:.*]]: !llvm.i64)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64

@@ -112,10 +123,8 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
}

// CHECK-LABEL: func @dynamic_load(
-// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64

@@ -131,8 +140,7 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {

// CHECK-LABEL: func @prefetch
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64

@@ -161,8 +169,7 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {

// CHECK-LABEL: func @dynamic_store
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64

@@ -171,15 +178,14 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
  store %val, %dynamic[%i, %j] : memref<?x?xf32>
  return
}

// CHECK-LABEL: func @mixed_store
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
@@ -188,74 +194,66 @@ func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32)
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
  store %val, %mixed[%i, %j] : memref<42x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_static_to_dynamic
func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_static_to_mixed
func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
  return
}

// CHECK-LABEL: func @memref_cast_dynamic_to_static
func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
  return
}

// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_dynamic
func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_static
func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_mixed
func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: llvm.bitcast %[[ld]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: llvm.bitcast %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> to !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
  return
}

// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG: llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
-// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
+// CHECK-DAG: llvm.store %{{.*}}, %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
+// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK : llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">

@@ -266,19 +264,17 @@ func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {

// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
-// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
-// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
+// CHECK: %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i8* }">
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
  %0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
  return
}

-// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @mixed_memref_dim
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
  %0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK-NEXT: llvm.extractvalue %[[ld:.*]][3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
  %1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
// CHECK-NEXT: llvm.extractvalue %[[ld]][3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
  %2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
@@ -18,12 +18,12 @@ func @fifth_order_left(%arg0: (((() -> ()) -> ()) -> ()) -> ())
//CHECK: llvm.func @fifth_order_right(!llvm<"void ()* ()* ()* ()*">)
func @fifth_order_right(%arg0: () -> (() -> (() -> (() -> ()))))

-// Check that memrefs are converted to pointers-to-struct if they appear as function arguments.
-// CHECK: llvm.func @memref_call_conv(!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)
+// Check that memrefs are converted to argument packs if they appear as function arguments.
+// CHECK: llvm.func @memref_call_conv(!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
func @memref_call_conv(%arg0: memref<?xf32>)

// Same in nested functions.
-// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void ({ float*, float*, i64, [1 x i64], [1 x i64] }*)*">)
+// CHECK: llvm.func @memref_call_conv_nested(!llvm<"void (float*, float*, i64, i64, i64)*">)
func @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())

//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm<"void ()*">) -> !llvm<"void ()*"> {
@@ -10,7 +10,10 @@ func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {

// -----

-// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// CHECK-LABEL: func @check_static_return
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-LABEL: func @check_static_return
// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
@@ -76,11 +79,10 @@ func @zero_d_alloc() -> memref<f32> {

// -----

-// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
+// CHECK-LABEL: func @zero_d_dealloc
// BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) {
func @zero_d_dealloc(%arg0: memref<f32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()

@@ -96,7 +98,7 @@ func @zero_d_dealloc(%arg0: memref<f32>) {
// CHECK-LABEL: func @aligned_1d_alloc(
// BAREPTR-LABEL: func @aligned_1d_alloc(
func @aligned_1d_alloc() -> memref<42xf32> {
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">

@@ -150,13 +152,13 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
// BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
func @static_alloc() -> memref<32x18xf32> {
-// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">

@@ -177,11 +179,10 @@ func @static_alloc() -> memref<32x18xf32> {

// -----

-// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// CHECK-LABEL: func @static_dealloc
// BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) {
func @static_dealloc(%static: memref<10x8xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()

@@ -194,11 +195,10 @@ func @static_dealloc(%static: memref<10x8xf32>) {

// -----

-// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
+// CHECK-LABEL: func @zero_d_load
// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float
func @zero_d_load(%arg0: memref<f32>) -> f32 {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
@@ -214,20 +214,22 @@ func @zero_d_load(%arg0: memref<f32>) -> f32 {
// -----

// CHECK-LABEL: func @static_load(
-// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
+// CHECK: %[[I:.*]]: !llvm.i64,
+// CHECK: %[[J:.*]]: !llvm.i64)
// BAREPTR-LABEL: func @static_load
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*">

// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

@@ -246,15 +248,14 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {

// -----

-// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
+// CHECK-LABEL: func @zero_d_store
// BAREPTR-LABEL: func @zero_d_store
// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float)
func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">

// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64

@@ -270,17 +271,16 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// BAREPTR-LABEL: func @static_store
// BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">
func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">

// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64

@@ -298,11 +298,10 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f

// -----

-// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// CHECK-LABEL: func @static_memref_dim
// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
  %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
@@ -728,9 +728,15 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
}

// CHECK-LABEL: func @subview(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+// CHECK-COUNT-2: !llvm<"float*">,
+// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64,
+// CHECK: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64,
+// CHECK: %[[ARG2:.*]]: !llvm.i64)
func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// The last "insertvalue" that populates the memref descriptor from the function arguments.
+// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]

// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

@@ -754,9 +760,10 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg
}

// CHECK-LABEL: func @subview_const_size(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// The last "insertvalue" that populates the memref descriptor from the function arguments.
+// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]

// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

@@ -782,9 +789,10 @@ func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 +
}

// CHECK-LABEL: func @subview_const_stride(
-// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
-// CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// The last "insertvalue" that populates the memref descriptor from the function arguments.
+// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]

// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -1,8 +1,7 @@
// RUN: mlir-opt %s -convert-std-to-llvm -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func @address_space(
-// CHECK: %{{.*}}: !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">)
-// CHECK: llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">
+// CHECK-SAME: !llvm<"float addrspace(7)*">
func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
  %0 = alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
  %1 = constant 7 : index
@@ -175,24 +175,22 @@ module attributes {gpu.container_module} {

// -----

-gpu.module @kernels {
-  gpu.func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
-    gpu.return
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @kernel_1(%arg1 : f32) attributes { gpu.kernel } {
+      gpu.return
+    }
  }
-}

+  func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
+    // expected-err@+1 {{type of function argument 0 does not match}}
+    "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
+        {kernel = "kernel_1", kernel_module = @kernels}
+        : (index, index, index, index, index, index, f32) -> ()
+    return
+  }
+}

-// Due to the ordering of the current impl of lowering and LLVMLowering, type
-// checks need to be temporarily disabled.
-// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
-// to encode target module" has landed.
-// func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
-//   // expected-err@+1 {{type of function argument 0 does not match}}
-//   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
-//   {kernel = "kernel_1"}
-//   : (index, index, index, index, index, index, f32) -> ()
-//   return
-// }

// -----

func @illegal_dimension() {
@ -52,9 +52,11 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
|
|||
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
|
||||
// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
|
||||
// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, float*, i64 }*">) -> ()
|
||||
// CHECK-LABEL: func @dot
|
||||
// CHECK: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}) :
|
||||
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64
|
||||
|
||||
func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?, 1]>) {
|
||||
%c0 = constant 0 : index
|
||||
|
@ -83,7 +85,9 @@ func @copy(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memre
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @copy
|
||||
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
|
||||
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32({{.*}}) :
|
||||
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
|
||||
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
|
||||
%0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
|
||||
|
@ -128,9 +132,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
|
|||
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// Call external copy after promoting input and output structs to pointers
|
||||
// CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
|
||||
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) -> ()
|
||||
// Call external copy.
|
||||
// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32
|
||||
|
||||
#matmul_accesses = [
|
||||
affine_map<(m, n, k) -> (m, k)>,
|
||||
|
@ -163,7 +166,10 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul_vec_impl(
|
||||
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
|
||||
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) :
|
||||
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
|
||||
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
|
||||
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
|
||||
|
@@ -195,7 +201,10 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
 }
 // CHECK-LABEL: func @matmul_vec_indexed(
 // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}) :
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
+// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
 
 func @reshape_static(%arg0: memref<3x4x5xf32>) {
   // Reshapes that expand and collapse back a contiguous tensor with some 1's.
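The same expansion applies to memrefs with vector element types: each rank-2 memref of `<4 x float>` elements above becomes two element pointers plus five `i64` values (offset, two sizes, two strides), which is exactly the seven types in each CHECK-SAME line. A hypothetical flattened prototype for one such argument group (`Vec4f` and the function name are stand-ins, not part of the change):

```cpp
#include <cstdint>

// Stand-in for the LLVM <4 x float> element type; illustrative only.
struct Vec4f { float v[4]; };

// Hypothetical flattened argument group for one memref<?x?xvector<4xf32>>:
// 2 pointers + offset + 2 sizes + 2 strides = 7 scalar arguments.
extern "C" void takes_one_vector_matrix(Vec4f *allocated, Vec4f *aligned,
                                        int64_t offset, int64_t size0,
                                        int64_t size1, int64_t stride0,
                                        int64_t stride1);
```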
@@ -15,32 +15,35 @@
 #include <assert.h>
 #include <iostream>
 
-extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
-                                        float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
   X->data[X->offset] = f;
 }
 
-extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
-                                          float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
+                                       float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     *(X->data + X->offset + i * X->strides[0]) = f;
 }
 
-extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
-                                            float f) {
+extern "C" void
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+                                         float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     for (unsigned j = 0; j < X->sizes[1]; ++j)
       *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
 }
 
-extern "C" void linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                                            StridedMemRefType<float, 0> *O) {
+extern "C" void
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+                                         StridedMemRefType<float, 0> *O) {
   O->data[O->offset] = I->data[I->offset];
 }
 
 extern "C" void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                StridedMemRefType<float, 1> *O) {
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+                                             StridedMemRefType<float, 1> *O) {
   if (I->sizes[0] != O->sizes[0]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *I);
@@ -52,9 +55,8 @@ linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
         I->data[I->offset + i * I->strides[0]];
 }
 
-extern "C" void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
-                                    StridedMemRefType<float, 2> *O) {
+extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
   if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *I);
@@ -69,10 +71,9 @@ linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
           I->data[I->offset + i * si0 + j * si1];
 }
 
-extern "C" void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
-                                       StridedMemRefType<float, 1> *Y,
-                                       StridedMemRefType<float, 0> *Z) {
+extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+    StridedMemRefType<float, 0> *Z) {
   if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
     std::cerr << "Incompatible strided memrefs\n";
     printMemRefMetaData(std::cerr, *X);
@@ -85,7 +86,7 @@ linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
                     Y->data + Y->offset, Y->strides[0]);
 }
 
-extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
     StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
     StridedMemRefType<float, 2> *C) {
   if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
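These `_mlir_ciface_`-prefixed functions keep the old pointer-to-descriptor interface and are now reached through the auxiliary wrappers emitted by the conversion. They rely on `StridedMemRefType` mirroring the LLVM descriptor struct field by field; a sketch of the template these files assume (field names taken from the code above, comments added):

```cpp
#include <cstdint>

// Sketch of the descriptor template assumed by the functions above; it
// mirrors the LLVM struct { T*, T*, i64, [N x i64], [N x i64] }.
template <typename T, int N> struct StridedMemRefType {
  T *basePtr;         // allocated pointer, kept for deallocation
  T *data;            // aligned pointer used for element access
  int64_t offset;     // offset into data, in elements
  int64_t sizes[N];   // size of each dimension
  int64_t strides[N]; // stride of each dimension, in elements
};
```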
@@ -25,33 +25,34 @@
 #endif // _WIN32
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
+_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
+_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X, float f);
+_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+                                         float f);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                            StridedMemRefType<float, 0> *O);
+_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+                                         StridedMemRefType<float, 0> *O);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                StridedMemRefType<float, 1> *O);
+_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+                                             StridedMemRefType<float, 1> *O);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
-                                    StridedMemRefType<float, 2> *O);
+_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
+    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
-                                       StridedMemRefType<float, 1> *Y,
-                                       StridedMemRefType<float, 0> *Z);
+_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
+    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
+    StridedMemRefType<float, 0> *Z);
 
 extern "C" MLIR_CBLAS_INTERFACE_EXPORT void
-linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
     StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
     StridedMemRefType<float, 2> *C);
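A C or C++ caller interacts with these declarations by materializing a descriptor and passing its address. A minimal sketch, repeating the `StridedMemRefType` layout above for self-containment and assuming the program is linked against the implementation:

```cpp
#include <cstdint>

template <typename T, int N> struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

extern "C" void
_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);

int main() {
  float buffer[8] = {};
  // Contiguous 1-d view over the whole buffer: offset 0, size 8, stride 1.
  StridedMemRefType<float, 1> X{buffer, buffer, /*offset=*/0, {8}, {1}};
  _mlir_ciface_linalg_fill_viewsxf32_f32(&X, 42.0f); // fills all 8 elements
  return 0;
}
```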
@@ -261,23 +261,27 @@ template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
 // Currently exposed C API.
 ////////////////////////////////////////////////////////////////////////////////
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_i8(UnrankedMemRefType<int8_t> *M);
+_mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_f32(UnrankedMemRefType<float> *M);
+_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(int64_t rank,
+                                                          void *ptr);
 
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_0d_f32(StridedMemRefType<float, 0> *M);
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_1d_f32(StridedMemRefType<float, 1> *M);
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_2d_f32(StridedMemRefType<float, 2> *M);
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_3d_f32(StridedMemRefType<float, 3> *M);
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_4d_f32(StridedMemRefType<float, 4> *M);
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
 
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
+_mlir_ciface_print_memref_vector_4x4xf32(
+    StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
 
 // Small runtime support "lib" for vector.print lowering.
 extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
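Unranked memrefs get two entry points here: the `_mlir_ciface_` variant takes a pointer to the unranked descriptor (the wrapped convention), while the new `print_memref_f32(int64_t rank, void *ptr)` form matches the expanded convention, under which an unranked memref is passed as its rank plus a type-erased pointer to the ranked descriptor. A sketch of the unranked descriptor, with fields matching the uses in the implementation below:

```cpp
#include <cstdint>

// Sketch of the unranked descriptor assumed above: the rank plus an opaque
// pointer to a ranked descriptor of that rank.
template <typename T> struct UnrankedMemRefType {
  int64_t rank;
  void *descriptor; // points at a StridedMemRefType<T, rank>
};
```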
@@ -16,8 +16,8 @@
 #include <cinttypes>
 #include <cstdio>
 
-extern "C" void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
+extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
+    StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
   impl::printMemRef(*M);
 }
 
@@ -26,7 +26,7 @@ print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
     impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr))); \
     break
 
-extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
+extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
   int rank = M->rank;
   void *ptr = M->descriptor;
@@ -42,7 +42,7 @@ extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   }
 }
 
-extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
+extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
   int rank = M->rank;
   void *ptr = M->descriptor;
@@ -58,19 +58,31 @@ extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
   }
 }
 
-extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
+extern "C" void print_memref_f32(int64_t rank, void *ptr) {
+  UnrankedMemRefType<float> descriptor;
+  descriptor.rank = rank;
+  descriptor.descriptor = ptr;
+  _mlir_ciface_print_memref_f32(&descriptor);
+}
+
+extern "C" void
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
   impl::printMemRef(*M);
 }
-extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
+extern "C" void
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
   impl::printMemRef(*M);
 }
-extern "C" void print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
+extern "C" void
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M) {
   impl::printMemRef(*M);
 }
-extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
+extern "C" void
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
   impl::printMemRef(*M);
 }
-extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
+extern "C" void
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
   impl::printMemRef(*M);
 }
 
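The new two-argument `print_memref_f32` simply rebuilds the unranked descriptor and forwards to the `_mlir_ciface_` version, so the expanded and the wrapped conventions share one implementation. A usage sketch from C++, assuming the descriptor layouts above and linkage against this runtime:

```cpp
#include <cstdint>

template <typename T, int N> struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

extern "C" void print_memref_f32(int64_t rank, void *ptr);

int main() {
  float buffer[4] = {1.f, 2.f, 3.f, 4.f};
  StridedMemRefType<float, 1> m{buffer, buffer, /*offset=*/0, {4}, {1}};
  // Pass the rank and a type-erased pointer to the ranked descriptor,
  // exactly the two values the expanded unranked convention carries.
  print_memref_f32(/*rank=*/1, &m);
  return 0;
}
```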
@@ -17,12 +17,13 @@ func @main() {
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
-  call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
+  %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+  call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
   call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
-  call @print_memref_1d_f32(%22) : (memref<?xf32>) -> ()
+  call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   return
 }
 
 func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @print_memref_1d_f32(memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)
@@ -96,11 +96,34 @@ void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
   std::fill_n(arg->data, count, value);
   mcuMemHostRegister(arg->data, count * sizeof(T));
 }
-extern "C" void
-mcuMemHostRegisterMemRef1dFloat(const MemRefType<float, 1> *arg) {
-  mcuMemHostRegisterMemRef(arg, 1.23f);
+
+extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
+                                                float *aligned, int64_t offset,
+                                                int64_t size, int64_t stride) {
+  MemRefType<float, 1> descriptor;
+  descriptor.basePtr = allocated;
+  descriptor.data = aligned;
+  descriptor.offset = offset;
+  descriptor.sizes[0] = size;
+  descriptor.strides[0] = stride;
+  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
 }
-extern "C" void
-mcuMemHostRegisterMemRef3dFloat(const MemRefType<float, 3> *arg) {
-  mcuMemHostRegisterMemRef(arg, 1.23f);
+
+extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
+                                                float *aligned, int64_t offset,
+                                                int64_t size0, int64_t size1,
+                                                int64_t size2, int64_t stride0,
+                                                int64_t stride1,
+                                                int64_t stride2) {
+  MemRefType<float, 3> descriptor;
+  descriptor.basePtr = allocated;
+  descriptor.data = aligned;
+  descriptor.offset = offset;
+  descriptor.sizes[0] = size0;
+  descriptor.strides[0] = stride0;
+  descriptor.sizes[1] = size1;
+  descriptor.strides[1] = stride1;
+  descriptor.sizes[2] = size2;
+  descriptor.strides[2] = stride2;
+  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
 }
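Under the expanded convention each ranked memref argument arrives as individual scalars, so these CUDA wrappers rebuild a descriptor on the host before reusing the generic registration helper. Writing one wrapper per rank scales poorly; a hypothetical rank-generic variant, not part of this change, could take the sizes and strides as arrays instead:

```cpp
#include <cstdint>

template <typename T, int N> struct MemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

// Hypothetical rank-generic rebuild of a descriptor from unpacked values.
// mcuMemHostRegisterMemRef is assumed to be the templated helper shown above.
template <typename T, int N>
void registerUnpackedMemRef(T *allocated, T *aligned, int64_t offset,
                            const int64_t (&sizes)[N],
                            const int64_t (&strides)[N]) {
  MemRefType<T, N> descriptor;
  descriptor.basePtr = allocated;
  descriptor.data = aligned;
  descriptor.offset = offset;
  for (int i = 0; i < N; ++i) {
    descriptor.sizes[i] = sizes[i];
    descriptor.strides[i] = strides[i];
  }
  // mcuMemHostRegisterMemRef(&descriptor, T(1.23)); // forward as above
}
```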