[mlir] Use thread_local stack in LLVM dialect type parsing and printing

LLVM dialect type parsing and printing have been using a local stack object
forwarded between recursive functions responsible for parsing or printing
specific types. This stack is necessary to intercept (mutually) recursive
structure types and avoid inifinite recursion. This approach works only thanks
to the closedness of the LLVM dialect type system: types that don't belong to
the dialect are not allowed. Switch the approach to using a `thread_local`
stack inside the functions parsing the structure types. This makes the code
slightly cleaner by avoiding the need to pass the stack object around and, more
importantly, makes it possible to reconsider the closedness of the LLVM dialect
type system. As a nice side effect of this change, container LLVM dialect types
now support type aliases in their body (although it is currently impossible to
also use the alises when printing).

Depends On D93713

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93714
This commit is contained in:
Alex Zinenko 2021-01-05 17:01:41 +01:00
parent 351a45ca73
commit 74438eff51
3 changed files with 188 additions and 152 deletions

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
@ -19,8 +20,14 @@ using namespace mlir::LLVM;
// Printing.
//===----------------------------------------------------------------------===//
static void printTypeImpl(llvm::raw_ostream &os, Type type,
llvm::SetVector<StringRef> &stack);
/// If the given type is compatible with the LLVM dialect, prints it using
/// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
/// prints it as usual.
static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
if (isCompatibleType(type))
return mlir::LLVM::detail::printType(type, printer);
printer.printType(type);
}
/// Returns the keyword to use for the given type.
static StringRef getTypeKeyword(Type type) {
@ -48,76 +55,79 @@ static StringRef getTypeKeyword(Type type) {
});
}
/// Prints the body of a structure type. Uses `stack` to avoid printing
/// recursive structs indefinitely.
static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
llvm::SetVector<StringRef> &stack) {
/// Prints a structure type. Keeps track of known struct names to handle self-
/// or mutually-referring structs without falling into infinite recursion.
static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
// This keeps track of the names of identified structure types that are
// currently being printed. Since such types can refer themselves, this
// tracking is necessary to stop the recursion: the current function may be
// called recursively from DialectAsmPrinter::printType after the appropriate
// dispatch. We maintain the invariant of this storage being modified
// exclusively in this function, and at most one name being added per call.
// TODO: consider having such functionality inside DialectAsmPrinter.
thread_local llvm::SetVector<StringRef> knownStructNames;
unsigned stackSize = knownStructNames.size();
(void)stackSize;
auto guard = llvm::make_scope_exit([&]() {
assert(knownStructNames.size() == stackSize &&
"malformed identified stack when printing recursive structs");
});
printer << "<";
if (type.isIdentified()) {
printer << '"' << type.getName() << '"';
// If we are printing a reference to one of the enclosing structs, just
// print the name and stop to avoid infinitely long output.
if (knownStructNames.count(type.getName())) {
printer << '>';
return;
}
printer << ", ";
}
if (type.isIdentified() && type.isOpaque()) {
os << "opaque";
printer << "opaque>";
return;
}
if (type.isPacked())
os << "packed ";
printer << "packed ";
// Put the current type on stack to avoid infinite recursion.
os << '(';
printer << '(';
if (type.isIdentified())
stack.insert(type.getName());
llvm::interleaveComma(type.getBody(), os, [&](Type subtype) {
printTypeImpl(os, subtype, stack);
});
knownStructNames.insert(type.getName());
llvm::interleaveComma(type.getBody(), printer.getStream(),
[&](Type subtype) { dispatchPrint(printer, subtype); });
if (type.isIdentified())
stack.pop_back();
os << ')';
}
/// Prints a structure type. Uses `stack` to keep track of the identifiers of
/// the structs being printed. Checks if the identifier of a struct is contained
/// in `stack`, i.e. whether a self-reference to a recursive stack is being
/// printed, and only prints the name to avoid infinite recursion.
static void printStructType(llvm::raw_ostream &os, LLVMStructType type,
llvm::SetVector<StringRef> &stack) {
os << "<";
if (type.isIdentified()) {
os << '"' << type.getName() << '"';
// If we are printing a reference to one of the enclosing structs, just
// print the name and stop to avoid infinitely long output.
if (stack.count(type.getName())) {
os << '>';
return;
}
os << ", ";
}
printStructTypeBody(os, type, stack);
os << '>';
knownStructNames.pop_back();
printer << ')';
printer << '>';
}
/// Prints a type containing a fixed number of elements.
template <typename TypeTy>
static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type,
llvm::SetVector<StringRef> &stack) {
os << '<' << type.getNumElements() << " x ";
printTypeImpl(os, type.getElementType(), stack);
os << '>';
static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
printer << '<' << type.getNumElements() << " x ";
dispatchPrint(printer, type.getElementType());
printer << '>';
}
/// Prints a function type.
static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
llvm::SetVector<StringRef> &stack) {
os << '<';
printTypeImpl(os, funcType.getReturnType(), stack);
os << " (";
llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) {
printTypeImpl(os, subtype, stack);
});
static void printFunctionType(DialectAsmPrinter &printer,
LLVMFunctionType funcType) {
printer << '<';
dispatchPrint(printer, funcType.getReturnType());
printer << " (";
llvm::interleaveComma(
funcType.getParams(), printer.getStream(),
[&printer](Type subtype) { dispatchPrint(printer, subtype); });
if (funcType.isVarArg()) {
if (funcType.getNumParams() != 0)
os << ", ";
os << "...";
printer << ", ";
printer << "...";
}
os << ")>";
printer << ")>";
}
/// Prints the given LLVM dialect type recursively. This leverages closedness of
@ -129,75 +139,59 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
/// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
/// ptr<struct<"b", (ptr<struct<"c">>)>>)>
/// note that "b" is printed twice.
static void printTypeImpl(llvm::raw_ostream &os, Type type,
llvm::SetVector<StringRef> &stack) {
void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
if (!type) {
os << "<<NULL-TYPE>>";
printer << "<<NULL-TYPE>>";
return;
}
os << getTypeKeyword(type);
printer << getTypeKeyword(type);
if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
os << intType.getBitWidth();
printer << intType.getBitWidth();
return;
}
if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
os << '<';
printTypeImpl(os, ptrType.getElementType(), stack);
printer << '<';
dispatchPrint(printer, ptrType.getElementType());
if (ptrType.getAddressSpace() != 0)
os << ", " << ptrType.getAddressSpace();
os << '>';
printer << ", " << ptrType.getAddressSpace();
printer << '>';
return;
}
if (auto arrayType = type.dyn_cast<LLVMArrayType>())
return printArrayOrVectorType(os, arrayType, stack);
return printArrayOrVectorType(printer, arrayType);
if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
return printArrayOrVectorType(os, vectorType, stack);
return printArrayOrVectorType(printer, vectorType);
if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
os << "<? x " << vectorType.getMinNumElements() << " x ";
printTypeImpl(os, vectorType.getElementType(), stack);
os << '>';
printer << "<? x " << vectorType.getMinNumElements() << " x ";
dispatchPrint(printer, vectorType.getElementType());
printer << '>';
return;
}
if (auto structType = type.dyn_cast<LLVMStructType>())
return printStructType(os, structType, stack);
return printStructType(printer, structType);
if (auto funcType = type.dyn_cast<LLVMFunctionType>())
return printFunctionType(os, funcType, stack);
}
void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
llvm::SetVector<StringRef> stack;
return printTypeImpl(printer.getStream(), type, stack);
return printFunctionType(printer, funcType);
}
//===----------------------------------------------------------------------===//
// Parsing.
//===----------------------------------------------------------------------===//
static Type parseTypeImpl(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack);
/// Helper to be chained with other parsing functions.
static ParseResult parseTypeImpl(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack,
Type &result) {
result = parseTypeImpl(parser, stack);
return success(result != nullptr);
}
static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
Type returnType;
if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
return LLVMFunctionType();
@ -219,9 +213,10 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
/*isVarArg=*/true);
}
argTypes.push_back(parseTypeImpl(parser, stack));
if (!argTypes.back())
Type arg;
if (dispatchParse(parser, arg))
return LLVMFunctionType();
argTypes.push_back(arg);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
@ -232,11 +227,10 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
static LLVMPointerType parsePointerType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
Type elementType;
if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
if (parser.parseLess() || dispatchParse(parser, elementType))
return LLVMPointerType();
unsigned addressSpace = 0;
@ -251,15 +245,14 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser,
/// Parses an LLVM dialect vector type.
/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
/// Supports both fixed and scalable vectors.
static LLVMVectorType parseVectorType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
static LLVMVectorType parseVectorType(DialectAsmParser &parser) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
dispatchParse(parser, elementType) || parser.parseGreater())
return LLVMVectorType();
// We parsed a generic dimension list, but vectors only support two forms:
@ -282,15 +275,14 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
/// Parses an LLVM dialect array type.
/// llvm-type ::= `array<` integer `x` llvm-type `>`
static LLVMArrayType parseArrayType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
dispatchParse(parser, elementType) || parser.parseGreater())
return LLVMArrayType();
if (dims.size() != 1) {
@ -302,13 +294,11 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
}
/// Attempts to set the body of an identified structure type. Reports a parsing
/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
/// types printed in the error message look like they did when parsed.
/// error at `subtypesLoc` in case of failure.
static LLVMStructType trySetStructBody(LLVMStructType type,
ArrayRef<Type> subtypes, bool isPacked,
DialectAsmParser &parser,
llvm::SMLoc subtypesLoc,
llvm::SetVector<StringRef> &stack) {
llvm::SMLoc subtypesLoc) {
for (Type t : subtypes) {
if (!LLVMStructType::isValidElementType(t)) {
parser.emitError(subtypesLoc)
@ -320,12 +310,8 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
if (succeeded(type.setBody(subtypes, isPacked)))
return type;
std::string currentBody;
llvm::raw_string_ostream currentBodyStream(currentBody);
printStructTypeBody(currentBodyStream, type, stack);
auto diag = parser.emitError(subtypesLoc)
<< "identified type already used with a different body";
diag.attachNote() << "existing body: " << currentBodyStream.str();
parser.emitError(subtypesLoc)
<< "identified type already used with a different body";
return LLVMStructType();
}
@ -334,8 +320,22 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
/// `(` llvm-type-list `)` `>`
/// | `struct<` string-literal `>`
/// | `struct<` string-literal `, opaque>`
static LLVMStructType parseStructType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
static LLVMStructType parseStructType(DialectAsmParser &parser) {
// This keeps track of the names of identified structure types that are
// currently being parsed. Since such types can refer themselves, this
// tracking is necessary to stop the recursion: the current function may be
// called recursively from DialectAsmParser::parseType after the appropriate
// dispatch. We maintain the invariant of this storage being modified
// exclusively in this function, and at most one name being added per call.
// TODO: consider having such functionality inside DialectAsmParser.
thread_local llvm::SetVector<StringRef> knownStructNames;
unsigned stackSize = knownStructNames.size();
(void)stackSize;
auto guard = llvm::make_scope_exit([&]() {
assert(knownStructNames.size() == stackSize &&
"malformed identified stack when parsing recursive structs");
});
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (failed(parser.parseLess()))
@ -347,7 +347,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
StringRef name;
bool isIdentified = succeeded(parser.parseOptionalString(&name));
if (isIdentified) {
if (stack.count(name)) {
if (knownStructNames.count(name)) {
if (failed(parser.parseGreater()))
return LLVMStructType();
return LLVMStructType::getIdentifiedChecked(loc, name);
@ -384,7 +384,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
return trySetStructBody(type, {}, isPacked, parser, kwLoc);
}
// Parse subtypes. For identified structs, put the identifier of the struct on
@ -393,13 +393,13 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
do {
if (isIdentified)
stack.insert(name);
Type type = parseTypeImpl(parser, stack);
if (!type)
knownStructNames.insert(name);
Type type;
if (dispatchParse(parser, type))
return LLVMStructType();
subtypes.push_back(type);
if (isIdentified)
stack.pop_back();
knownStructNames.pop_back();
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen() || parser.parseGreater())
@ -409,30 +409,30 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
}
/// Parses one of the LLVM dialect types.
static Type parseTypeImpl(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
// Special case for integers (i[1-9][0-9]*) that are literals rather than
// keywords for the parser, so they are not caught by the main dispatch below.
// Try parsing it a built-in integer type instead.
Type maybeIntegerType;
MLIRContext *ctx = parser.getBuilder().getContext();
/// Parses a type appearing inside another LLVM dialect-compatible type. This
/// will try to parse any type in full form (including types with the `!llvm`
/// prefix), and on failure fall back to parsing the short-hand version of the
/// LLVM dialect types without the `!llvm` prefix.
static Type dispatchParse(DialectAsmParser &parser) {
Type type;
llvm::SMLoc keyLoc = parser.getCurrentLocation();
Location loc = parser.getEncodedSourceLoc(keyLoc);
OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
if (result.hasValue()) {
if (failed(*result))
OptionalParseResult parseResult = parser.parseOptionalType(type);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return Type();
if (!maybeIntegerType.isSignlessInteger()) {
parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
return Type();
}
return LLVMIntegerType::getChecked(
loc, maybeIntegerType.getIntOrFloatBitWidth());
// Special case for integers (i[1-9][0-9]*) that are literals rather than
// keywords for the parser, so they are not caught by the main dispatch
// below. Try parsing it a built-in integer type instead.
auto intType = type.dyn_cast<IntegerType>();
if (!intType || !intType.isSignless())
return type;
return LLVMIntegerType::getChecked(loc, intType.getWidth());
}
// Dispatch to concrete functions.
@ -440,6 +440,7 @@ static Type parseTypeImpl(DialectAsmParser &parser,
if (failed(parser.parseKeyword(&key)))
return Type();
MLIRContext *ctx = parser.getBuilder().getContext();
return StringSwitch<function_ref<Type()>>(key)
.Case("void", [&] { return LLVMVoidType::get(ctx); })
.Case("half", [&] { return LLVMHalfType::get(ctx); })
@ -453,18 +454,32 @@ static Type parseTypeImpl(DialectAsmParser &parser,
.Case("token", [&] { return LLVMTokenType::get(ctx); })
.Case("label", [&] { return LLVMLabelType::get(ctx); })
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
.Case("func", [&] { return parseFunctionType(parser, stack); })
.Case("ptr", [&] { return parsePointerType(parser, stack); })
.Case("vec", [&] { return parseVectorType(parser, stack); })
.Case("array", [&] { return parseArrayType(parser, stack); })
.Case("struct", [&] { return parseStructType(parser, stack); })
.Case("func", [&] { return parseFunctionType(parser); })
.Case("ptr", [&] { return parsePointerType(parser); })
.Case("vec", [&] { return parseVectorType(parser); })
.Case("array", [&] { return parseArrayType(parser); })
.Case("struct", [&] { return parseStructType(parser); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
return Type();
})();
}
Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
llvm::SetVector<StringRef> stack;
return parseTypeImpl(parser, stack);
/// Helper to use in parse lists.
static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
type = dispatchParse(parser);
return success(type != nullptr);
}
/// Parses one of the LLVM dialect types.
Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
llvm::SMLoc loc = parser.getCurrentLocation();
Type type = dispatchParse(parser);
if (!type)
return type;
if (!isCompatibleType(type)) {
parser.emitError(loc) << "unexpected type, expected i* or keyword";
return nullptr;
}
return type;
}

View File

@ -30,8 +30,7 @@ func @void_pointer() {
func @repeated_struct_name() {
"some.op"() : () -> !llvm.struct<"a", (ptr<struct<"a">>)>
// expected-error @+2 {{identified type already used with a different body}}
// expected-note @+1 {{existing body: (ptr<struct<"a">>)}}
// expected-error @+1 {{identified type already used with a different body}}
"some.op"() : () -> !llvm.struct<"a", (i32)>
}
@ -39,8 +38,7 @@ func @repeated_struct_name() {
func @repeated_struct_name_packed() {
"some.op"() : () -> !llvm.struct<"a", packed (i32)>
// expected-error @+2 {{identified type already used with a different body}}
// expected-note @+1 {{existing body: packed (i32)}}
// expected-error @+1 {{identified type already used with a different body}}
"some.op"() : () -> !llvm.struct<"a", (i32)>
}
@ -48,8 +46,7 @@ func @repeated_struct_name_packed() {
func @repeated_struct_opaque() {
"some.op"() : () -> !llvm.struct<"a", opaque>
// expected-error @+2 {{identified type already used with a different body}}
// expected-note @+1 {{existing body: opaque}}
// expected-error @+1 {{identified type already used with a different body}}
"some.op"() : () -> !llvm.struct<"a", ()>
}
@ -57,8 +54,7 @@ func @repeated_struct_opaque() {
func @repeated_struct_opaque_non_empty() {
"some.op"() : () -> !llvm.struct<"a", opaque>
// expected-error @+2 {{identified type already used with a different body}}
// expected-note @+1 {{existing body: opaque}}
// expected-error @+1 {{identified type already used with a different body}}
"some.op"() : () -> !llvm.struct<"a", (i32, i32)>
}
@ -95,8 +91,7 @@ func @unexpected_type() {
func @explicitly_opaque_struct() {
"some.op"() : () -> !llvm.struct<"a", opaque>
// expected-error @+2 {{identified type already used with a different body}}
// expected-note @+1 {{existing body: opaque}}
// expected-error @+1 {{identified type already used with a different body}}
"some.op"() : () -> !llvm.struct<"a", ()>
}

View File

@ -182,3 +182,29 @@ func @identified_struct() {
return
}
func @verbose() {
// CHECK: !llvm.struct<(i64, struct<(float)>)>
"some.op"() : () -> !llvm.struct<(!llvm.i64, !llvm.struct<(!llvm.float)>)>
return
}
// -----
// Check that type aliases can be used inside LLVM dialect types. Note that
// currently they are _not_ printed back as this would require
// DialectAsmPrinter to have a mechanism for querying the presence and
// usability of an alias outside of its `printType` method.
!baz = type !llvm.i64
!qux = type !llvm.struct<(!baz)>
!rec = type !llvm.struct<"a", (ptr<struct<"a">>)>
// CHECK: aliases
llvm.func @aliases() {
// CHECK: !llvm.struct<(i32, float, struct<(i64)>)>
"some.op"() : () -> !llvm.struct<(i32, float, !qux)>
// CHECK: !llvm.struct<"a", (ptr<struct<"a">>)>
"some.op"() : () -> !rec
llvm.return
}