[mlir][llvm] Pass struct results as parameter in c wrapper

Returning structs directly in LLVM does not necessarily align with the C ABI of
the platform. This might happen to work on Linux but for small structs this
breaks on Windows. With this change, the wrappers work platform independently.

Differential Revision: https://reviews.llvm.org/D98725
This commit is contained in:
Stephan Herhut 2021-03-17 12:16:30 +01:00
parent ecfa874531
commit 5837fdc4cc
4 changed files with 146 additions and 35 deletions

View File

@ -232,29 +232,40 @@ struct MemRefDescriptor {
};
```
Furthermore, we also rewrite function results to pointer parameters if the
rewritten function result has a struct type. The special result parameter is
added as the first parameter and is of pointer-to-struct type.
If enabled, the option will do the following. For _external_ functions declared
in the MLIR module.
1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
as usual.
1. Add a body to the original function (making it non-external) that
1. allocates a memref descriptor,
1. populates it, and
1. passes the pointer to it into the newly declared interface function,
as usual. Results are converted to a special argument if they are of struct
type.
2. Add a body to the original function (making it non-external) that
1. allocates memref descriptors,
2. populates them,
3. potentially allocates space for the result struct, and
4. passes the pointers to these into the newly declared interface function,
then
1. collects the result of the call and returns it to the caller.
5. collects the result of the call (potentially from the result struct),
and
6. returns it to the caller.
For (non-external) functions defined in the MLIR module.
1. Define a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
as usual.
1. Populate the body of the newly defined function with IR that
as usual. Results are converted to a special argument if they are of struct
type.
2. Populate the body of the newly defined function with IR that
1. loads descriptors from pointers;
1. unpacks descriptor into individual non-aggregate values;
1. passes these values into the original function;
1. collects the result of the call and returns it to the caller.
2. unpacks descriptor into individual non-aggregate values;
3. passes these values into the original function;
4. collects the results of the call and
5. either copies the results into the result struct or returns them to the
caller.
Examples:
@ -342,6 +353,49 @@ llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr) {
}
```
```mlir
func @foo(%arg0: memref<?x?xf32>) -> memref<?x?xf32> {
return %arg0 : memref<?x?xf32>
}
// Gets converted into the following
// (using type alias for brevity):
!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
array<2xi64>, array<2xi64>)>
!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
array<2xi64>, array<2xi64>)>>
// Function with unpacked arguments.
llvm.func @foo(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64)
-> !llvm.memref_2d {
%0 = llvm.mlir.undef : !llvm.memref_2d
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.memref_2d
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.memref_2d
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.memref_2d
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.memref_2d
%5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.memref_2d
%6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm.memref_2d
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.memref_2d
llvm.return %7 : !llvm.memref_2d
}
// Interface function callable from C.
llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr, %arg1: !llvm.memref_2d_ptr) {
%0 = llvm.load %arg1 : !llvm.memref_2d_ptr
%1 = llvm.extractvalue %0[0] : !llvm.memref_2d
%2 = llvm.extractvalue %0[1] : !llvm.memref_2d
%3 = llvm.extractvalue %0[2] : !llvm.memref_2d
%4 = llvm.extractvalue %0[3, 0] : !llvm.memref_2d
%5 = llvm.extractvalue %0[3, 1] : !llvm.memref_2d
%6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d
%7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d
%8 = llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
: (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, i64, i64) -> !llvm.memref_2d
llvm.store %8, %arg0 : !llvm.memref_2d_ptr
llvm.return
}
Rationale: Introducing auxiliary functions for C-compatible interfaces is
preferred to modifying the calling convention since it will minimize the effect
of C compatibility on intra-module calls or calls between MLIR-generated

View File

@ -116,8 +116,10 @@ public:
OpBuilder &builder);
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
Type convertFunctionTypeCWrapper(FunctionType type);
/// pointers to memref descriptors for arguments. Also converts the return
/// type to a pointer argument if it is a struct. Returns true if this
/// was the case.
std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
/// Returns the data layout to use during and after conversion.
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }

View File

@ -253,8 +253,24 @@ Type LLVMTypeConverter::convertFunctionSignature(
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
std::pair<Type, bool>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
SmallVector<Type, 4> inputs;
bool resultIsNowArg = false;
Type resultType = type.getNumResults() == 0
? LLVM::LLVMVoidType::get(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
// Struct types cannot be safely returned via C interface. Make this a
// pointer argument, instead.
inputs.push_back(LLVM::LLVMPointerType::get(structType));
resultType = LLVM::LLVMVoidType::get(&getContext());
resultIsNowArg = true;
}
for (Type t : type.getInputs()) {
auto converted = convertType(t);
@ -265,13 +281,7 @@ Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
inputs.push_back(converted);
}
Type resultType = type.getNumResults() == 0
? LLVM::LLVMVoidType::get(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMFunctionType::get(resultType, inputs);
return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
}
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
@ -1212,8 +1222,11 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. This function can be called from C
/// by passing a pointer to a C struct corresponding to a memref descriptor.
/// Similarly, returned memrefs are passed via pointers to a C struct that is
/// passed as additional argument.
/// Internally, the auxiliary function unpacks the descriptor into individual
/// components and forwards them to `newFuncOp`.
/// components and forwards them to `newFuncOp` and forwards the results to
/// the extra arguments.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
@ -1221,17 +1234,21 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
attributes);
Type wrapperFuncType;
bool resultIsNowArg;
std::tie(wrapperFuncType, resultIsNowArg) =
typeConverter.convertFunctionTypeCWrapper(type);
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
attributes);
wrapperFuncType, LLVM::Linkage::External, attributes);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
SmallVector<Value, 8> args;
size_t argOffset = resultIsNowArg ? 1 : 0;
for (auto &en : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(en.index());
Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
@ -1243,28 +1260,40 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
continue;
}
args.push_back(wrapperFuncOp.getArgument(en.index()));
args.push_back(arg);
}
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
if (resultIsNowArg) {
rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
wrapperFuncOp.getArgument(0));
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
} else {
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
}
}
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
/// the auxiliary function. This auxiliary external function is now compatible
/// with functions defined in C using pointers to C structs corresponding to a
/// memref descriptor.
/// the auxiliary function. If the result of the function cannot be directly
/// returned, we write it to a special first argument that provides a pointer
/// to a corresponding struct. This auxiliary external function is now
/// compatible with functions defined in C using pointers to C structs
/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
Type wrapperType =
Type wrapperType;
bool resultIsNowArg;
std::tie(wrapperType, resultIsNowArg) =
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
// This conversion can only fail if it could not convert one of the argument
// types. But since it has been applies to a non-wrapper function before, it
// types. But since it has been applied to a non-wrapper function before, it
// should have failed earlier and not reach this point at all.
assert(wrapperType && "unexpected type conversion failure");
@ -1285,6 +1314,17 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
args.reserve(type.getNumInputs());
ValueRange wrapperArgsRange(newFuncOp.getArguments());
if (resultIsNowArg) {
// Allocate the struct on the stack and pass the pointer.
Type resultType =
wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
args.push_back(result);
}
// Iterate over the inputs of the original function and pack values into
// memref descriptors if the original type is a memref.
for (auto &en : llvm::enumerate(type.getInputs())) {
@ -1322,8 +1362,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
if (resultIsNowArg) {
Value result = builder.create<LLVM::LoadOp>(loc, args.front());
builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
} else {
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
}
namespace {

View File

@ -144,7 +144,7 @@ func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
}
// CHECK-LABEL: llvm.func @return_var_memref
func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { llvm.emit_c_interface } {
// Match the construction of the unranked descriptor.
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@ -177,6 +177,10 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
return %0 : memref<*xf32>
}
// Check that the result memref is passed as parameter
// CHECK-LABEL: @_mlir_ciface_return_var_memref
// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(i64, ptr<i8>)>>, %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
// CHECK-LABEL: llvm.func @return_two_var_memref_caller
func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
// Only check that we create two different descriptors using different
@ -206,7 +210,7 @@ func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
}
// CHECK-LABEL: llvm.func @return_two_var_memref
func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) attributes { llvm.emit_c_interface } {
// Match the construction of the unranked descriptor.
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@ -240,3 +244,8 @@ func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*x
return %0, %0 : memref<*xf32>, memref<*xf32>
}
// Check that the result memrefs are passed as parameter
// CHECK-LABEL: @_mlir_ciface_return_two_var_memref
// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>>,
// CHECK-SAME: %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)