[Flang][mlir] add a band-aid to support the creation of mutually recursive types when lowering to LLVM IR

Summary:
This is a temporary implementation to support Flang.  The LLVM-IR parser
will need to be extended in some way to support recursive types.  The
exact approach here is still a work-in-progress.

Unfortunately, this won't pass roundtrip testing yet. Adding a comment
to the test file as a reminder.

Differential Revision: https://reviews.llvm.org/D72542
This commit is contained in:
Eric Schweitz 2020-01-17 21:07:58 +01:00 committed by Alex Zinenko
parent 44aaca3de4
commit 37e2560d3d
3 changed files with 82 additions and 0 deletions

View File

@ -138,6 +138,47 @@ public:
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
static LLVMType getVoidTy(LLVMDialect *dialect);
// Creation and setting of LLVM's identified struct types
static LLVMType createStructTy(LLVMDialect *dialect,
ArrayRef<LLVMType> elements,
Optional<StringRef> name,
bool isPacked = false);
static LLVMType createStructTy(LLVMDialect *dialect,
Optional<StringRef> name) {
return createStructTy(dialect, llvm::None, name);
}
static LLVMType createStructTy(ArrayRef<LLVMType> elements,
Optional<StringRef> name,
bool isPacked = false) {
assert(!elements.empty() &&
"This method may not be invoked with an empty list");
LLVMType ele0 = elements.front();
return createStructTy(&ele0.getDialect(), elements, name, isPacked);
}
template <typename... Args>
static typename std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>
createStructTy(StringRef name, LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
Optional<StringRef> opt_name(name);
return createStructTy(&elt1.getDialect(), fields, opt_name);
}
static LLVMType setStructTyBody(LLVMType structType,
ArrayRef<LLVMType> elements,
bool isPacked = false);
template <typename... Args>
static typename std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>
setStructTyBody(LLVMType structType, LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
return setStructTyBody(structType, fields);
}
private:
friend LLVMDialect;

View File

@ -1641,6 +1641,35 @@ LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
isPacked);
});
}
inline static SmallVector<llvm::Type *, 8>
toUnderlyingTypes(ArrayRef<LLVMType> elements) {
SmallVector<llvm::Type *, 8> llvmElements;
for (auto elt : elements)
llvmElements.push_back(elt.getUnderlyingType());
return llvmElements;
}
LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
ArrayRef<LLVMType> elements,
Optional<StringRef> name, bool isPacked) {
StringRef sr = name.hasValue() ? *name : "";
SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
return getLocked(dialect, [=] {
auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
if (!llvmElements.empty())
rv->setBody(llvmElements, isPacked);
return rv;
});
}
LLVMType LLVMType::setStructTyBody(LLVMType structType,
ArrayRef<LLVMType> elements, bool isPacked) {
llvm::StructType *st =
llvm::cast<llvm::StructType>(structType.getUnderlyingType());
SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
return getLocked(&structType.getDialect(), [=] {
st->setBody(llvmElements, isPacked);
return st;
});
}
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
// Lock access to the dialect as this may modify the LLVM context.
return getLocked(&elementType.getDialect(), [=] {

View File

@ -382,3 +382,15 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
}
// -----
// FIXME: the LLVM-IR dialect should parse mutually recursive types
// CHECK-LABEL: @recursive_type
// expected-error@+1 {{expected end of string}}
llvm.func @recursive_type(%a : !llvm<"%a = type { %a* }">) ->
!llvm<"%a = type { %a* }"> {
llvm.return %a : !llvm<"%a = type { %a* }">
}