forked from OSchip/llvm-project
[mlir][llvm] Pass struct results as parameter in c wrapper
Returning structs directly in LLVM does not necessarily align with the C ABI of the platform. This might happen to work on Linux but for small structs this breaks on Windows. With this change, the wrappers work platform independently. Differential Revision: https://reviews.llvm.org/D98725
This commit is contained in:
parent
ecfa874531
commit
5837fdc4cc
|
@ -232,29 +232,40 @@ struct MemRefDescriptor {
|
|||
};
|
||||
```
|
||||
|
||||
Furthermore, we also rewrite function results to pointer parameters if the
|
||||
rewritten function result has a struct type. The special result parameter is
|
||||
added as the first parameter and is of pointer-to-struct type.
|
||||
|
||||
If enabled, the option will do the following. For _external_ functions declared
|
||||
in the MLIR module.
|
||||
|
||||
1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
|
||||
are converted to pointer-to-struct and the remaining arguments are converted
|
||||
as usual.
|
||||
1. Add a body to the original function (making it non-external) that
|
||||
1. allocates a memref descriptor,
|
||||
1. populates it, and
|
||||
1. passes the pointer to it into the newly declared interface function,
|
||||
as usual. Results are converted to a special argument if they are of struct
|
||||
type.
|
||||
2. Add a body to the original function (making it non-external) that
|
||||
1. allocates memref descriptors,
|
||||
2. populates them,
|
||||
3. potentially allocates space for the result struct, and
|
||||
4. passes the pointers to these into the newly declared interface function,
|
||||
then
|
||||
1. collects the result of the call and returns it to the caller.
|
||||
5. collects the result of the call (potentially from the result struct),
|
||||
and
|
||||
6. returns it to the caller.
|
||||
|
||||
For (non-external) functions defined in the MLIR module.
|
||||
|
||||
1. Define a new function `_mlir_ciface_<original name>` where memref arguments
|
||||
are converted to pointer-to-struct and the remaining arguments are converted
|
||||
as usual.
|
||||
1. Populate the body of the newly defined function with IR that
|
||||
as usual. Results are converted to a special argument if they are of struct
|
||||
type.
|
||||
2. Populate the body of the newly defined function with IR that
|
||||
1. loads descriptors from pointers;
|
||||
1. unpacks descriptor into individual non-aggregate values;
|
||||
1. passes these values into the original function;
|
||||
1. collects the result of the call and returns it to the caller.
|
||||
2. unpacks descriptor into individual non-aggregate values;
|
||||
3. passes these values into the original function;
|
||||
4. collects the results of the call and
|
||||
5. either copies the results into the result struct or returns them to the
|
||||
caller.
|
||||
|
||||
Examples:
|
||||
|
||||
|
@ -342,6 +353,49 @@ llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr) {
|
|||
}
|
||||
```
|
||||
|
||||
```mlir
|
||||
func @foo(%arg0: memref<?x?xf32>) -> memref<?x?xf32> {
|
||||
return %arg0 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// Gets converted into the following
|
||||
// (using type alias for brevity):
|
||||
!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
|
||||
array<2xi64>, array<2xi64>)>
|
||||
!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
|
||||
array<2xi64>, array<2xi64>)>>
|
||||
|
||||
// Function with unpacked arguments.
|
||||
llvm.func @foo(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64,
|
||||
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64)
|
||||
-> !llvm.memref_2d {
|
||||
%0 = llvm.mlir.undef : !llvm.memref_2d
|
||||
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.memref_2d
|
||||
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.memref_2d
|
||||
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.memref_2d
|
||||
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.memref_2d
|
||||
%5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.memref_2d
|
||||
%6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm.memref_2d
|
||||
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.memref_2d
|
||||
llvm.return %7 : !llvm.memref_2d
|
||||
}
|
||||
|
||||
// Interface function callable from C.
|
||||
llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr, %arg1: !llvm.memref_2d_ptr) {
|
||||
%0 = llvm.load %arg1 : !llvm.memref_2d_ptr
|
||||
%1 = llvm.extractvalue %0[0] : !llvm.memref_2d
|
||||
%2 = llvm.extractvalue %0[1] : !llvm.memref_2d
|
||||
%3 = llvm.extractvalue %0[2] : !llvm.memref_2d
|
||||
%4 = llvm.extractvalue %0[3, 0] : !llvm.memref_2d
|
||||
%5 = llvm.extractvalue %0[3, 1] : !llvm.memref_2d
|
||||
%6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d
|
||||
%7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d
|
||||
%8 = llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
|
||||
: (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, i64, i64) -> !llvm.memref_2d
|
||||
llvm.store %8, %arg0 : !llvm.memref_2d_ptr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
Rationale: Introducing auxiliary functions for C-compatible interfaces is
|
||||
preferred to modifying the calling convention since it will minimize the effect
|
||||
of C compatibility on intra-module calls or calls between MLIR-generated
|
||||
|
|
|
@ -116,8 +116,10 @@ public:
|
|||
OpBuilder &builder);
|
||||
|
||||
/// Converts the function type to a C-compatible format, in particular using
|
||||
/// pointers to memref descriptors for arguments.
|
||||
Type convertFunctionTypeCWrapper(FunctionType type);
|
||||
/// pointers to memref descriptors for arguments. Also converts the return
|
||||
/// type to a pointer argument if it is a struct. Returns true if this
|
||||
/// was the case.
|
||||
std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
|
||||
|
||||
/// Returns the data layout to use during and after conversion.
|
||||
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
|
||||
|
|
|
@ -253,8 +253,24 @@ Type LLVMTypeConverter::convertFunctionSignature(
|
|||
|
||||
/// Converts the function type to a C-compatible format, in particular using
|
||||
/// pointers to memref descriptors for arguments.
|
||||
Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
||||
std::pair<Type, bool>
|
||||
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
||||
SmallVector<Type, 4> inputs;
|
||||
bool resultIsNowArg = false;
|
||||
|
||||
Type resultType = type.getNumResults() == 0
|
||||
? LLVM::LLVMVoidType::get(&getContext())
|
||||
: unwrap(packFunctionResults(type.getResults()));
|
||||
if (!resultType)
|
||||
return {};
|
||||
|
||||
if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
|
||||
// Struct types cannot be safely returned via C interface. Make this a
|
||||
// pointer argument, instead.
|
||||
inputs.push_back(LLVM::LLVMPointerType::get(structType));
|
||||
resultType = LLVM::LLVMVoidType::get(&getContext());
|
||||
resultIsNowArg = true;
|
||||
}
|
||||
|
||||
for (Type t : type.getInputs()) {
|
||||
auto converted = convertType(t);
|
||||
|
@ -265,13 +281,7 @@ Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
|||
inputs.push_back(converted);
|
||||
}
|
||||
|
||||
Type resultType = type.getNumResults() == 0
|
||||
? LLVM::LLVMVoidType::get(&getContext())
|
||||
: unwrap(packFunctionResults(type.getResults()));
|
||||
if (!resultType)
|
||||
return {};
|
||||
|
||||
return LLVM::LLVMFunctionType::get(resultType, inputs);
|
||||
return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
|
||||
}
|
||||
|
||||
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
|
||||
|
@ -1212,8 +1222,11 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
|||
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
|
||||
/// arguments instead of unpacked arguments. This function can be called from C
|
||||
/// by passing a pointer to a C struct corresponding to a memref descriptor.
|
||||
/// Similarly, returned memrefs are passed via pointers to a C struct that is
|
||||
/// passed as additional argument.
|
||||
/// Internally, the auxiliary function unpacks the descriptor into individual
|
||||
/// components and forwards them to `newFuncOp`.
|
||||
/// components and forwards them to `newFuncOp` and forwards the results to
|
||||
/// the extra arguments.
|
||||
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
|
||||
|
@ -1221,17 +1234,21 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
|||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
|
||||
attributes);
|
||||
Type wrapperFuncType;
|
||||
bool resultIsNowArg;
|
||||
std::tie(wrapperFuncType, resultIsNowArg) =
|
||||
typeConverter.convertFunctionTypeCWrapper(type);
|
||||
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
|
||||
attributes);
|
||||
wrapperFuncType, LLVM::Linkage::External, attributes);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
|
||||
|
||||
SmallVector<Value, 8> args;
|
||||
size_t argOffset = resultIsNowArg ? 1 : 0;
|
||||
for (auto &en : llvm::enumerate(type.getInputs())) {
|
||||
Value arg = wrapperFuncOp.getArgument(en.index());
|
||||
Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
|
||||
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
|
||||
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
|
||||
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
|
||||
|
@ -1243,28 +1260,40 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
|||
continue;
|
||||
}
|
||||
|
||||
args.push_back(wrapperFuncOp.getArgument(en.index()));
|
||||
args.push_back(arg);
|
||||
}
|
||||
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
|
||||
|
||||
if (resultIsNowArg) {
|
||||
rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
|
||||
wrapperFuncOp.getArgument(0));
|
||||
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
|
||||
} else {
|
||||
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
|
||||
/// arguments instead of unpacked arguments. Creates a body for the (external)
|
||||
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
|
||||
/// individual arguments into this descriptor and passes a pointer to it into
|
||||
/// the auxiliary function. This auxiliary external function is now compatible
|
||||
/// with functions defined in C using pointers to C structs corresponding to a
|
||||
/// memref descriptor.
|
||||
/// the auxiliary function. If the result of the function cannot be directly
|
||||
/// returned, we write it to a special first argument that provides a pointer
|
||||
/// to a corresponding struct. This auxiliary external function is now
|
||||
/// compatible with functions defined in C using pointers to C structs
|
||||
/// corresponding to a memref descriptor.
|
||||
static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
|
||||
Type wrapperType =
|
||||
Type wrapperType;
|
||||
bool resultIsNowArg;
|
||||
std::tie(wrapperType, resultIsNowArg) =
|
||||
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
|
||||
// This conversion can only fail if it could not convert one of the argument
|
||||
// types. But since it has been applies to a non-wrapper function before, it
|
||||
// types. But since it has been applied to a non-wrapper function before, it
|
||||
// should have failed earlier and not reach this point at all.
|
||||
assert(wrapperType && "unexpected type conversion failure");
|
||||
|
||||
|
@ -1285,6 +1314,17 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
|||
args.reserve(type.getNumInputs());
|
||||
ValueRange wrapperArgsRange(newFuncOp.getArguments());
|
||||
|
||||
if (resultIsNowArg) {
|
||||
// Allocate the struct on the stack and pass the pointer.
|
||||
Type resultType =
|
||||
wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
|
||||
Value one = builder.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.convertType(builder.getIndexType()),
|
||||
builder.getIntegerAttr(builder.getIndexType(), 1));
|
||||
Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
|
||||
args.push_back(result);
|
||||
}
|
||||
|
||||
// Iterate over the inputs of the original function and pack values into
|
||||
// memref descriptors if the original type is a memref.
|
||||
for (auto &en : llvm::enumerate(type.getInputs())) {
|
||||
|
@ -1322,8 +1362,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
|||
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
|
||||
|
||||
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
|
||||
|
||||
if (resultIsNowArg) {
|
||||
Value result = builder.create<LLVM::LoadOp>(loc, args.front());
|
||||
builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
|
||||
} else {
|
||||
builder.create<LLVM::ReturnOp>(loc, call.getResults());
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
|
|
|
@ -144,7 +144,7 @@ func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @return_var_memref
|
||||
func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
|
||||
func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { llvm.emit_c_interface } {
|
||||
// Match the construction of the unranked descriptor.
|
||||
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
|
||||
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
|
||||
|
@ -177,6 +177,10 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
|
|||
return %0 : memref<*xf32>
|
||||
}
|
||||
|
||||
// Check that the result memref is passed as parameter
|
||||
// CHECK-LABEL: @_mlir_ciface_return_var_memref
|
||||
// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(i64, ptr<i8>)>>, %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
|
||||
|
||||
// CHECK-LABEL: llvm.func @return_two_var_memref_caller
|
||||
func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
|
||||
// Only check that we create two different descriptors using different
|
||||
|
@ -206,7 +210,7 @@ func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @return_two_var_memref
|
||||
func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
|
||||
func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) attributes { llvm.emit_c_interface } {
|
||||
// Match the construction of the unranked descriptor.
|
||||
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
|
||||
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
|
||||
|
@ -240,3 +244,8 @@ func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*x
|
|||
return %0, %0 : memref<*xf32>, memref<*xf32>
|
||||
}
|
||||
|
||||
// Check that the result memrefs are passed as parameter
|
||||
// CHECK-LABEL: @_mlir_ciface_return_two_var_memref
|
||||
// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>>,
|
||||
// CHECK-SAME: %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
|
||||
|
||||
|
|
Loading…
Reference in New Issue