NFC: Split up Parser::parseAttribute into multiple smaller functions to improve readability.

PiperOrigin-RevId: 251158192
This commit is contained in:
River Riddle 2019-06-02 20:33:29 -07:00 committed by Mehdi Amini
parent c914976c72
commit b1393c2cd0
1 changed files with 288 additions and 291 deletions

View File

@ -224,10 +224,25 @@ public:
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
/// Parse an extended attribute.
Attribute parseExtendedAttribute(Type type);
Attribute parseExtendedAttr(Type type);
/// Parse a float attribute.
Attribute parseFloatAttr(Type type, bool isNegative);
/// Parse an integer attribute.
Attribute parseIntegerAttr(Type type, bool isSigned);
/// Parse an opaque elements attribute.
Attribute parseOpaqueElementsAttr();
/// Parse a sparse elements attribute.
Attribute parseSparseElementsAttr();
/// Parse a splat elements attribute.
Attribute parseSplatElementsAttr();
/// Parse a dense elements attribute.
DenseElementsAttr parseDenseElementsAttr(ShapedType type);
Attribute parseDenseElementsAttr();
DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
ShapedType parseElementsLiteralType();
@ -928,142 +943,7 @@ ParseResult Parser::parseXInDimensionList() {
///
Attribute Parser::parseAttribute(Type type) {
switch (getToken().getKind()) {
case Token::hash_identifier:
return parseExtendedAttribute(type);
case Token::kw_unit:
consumeToken(Token::kw_unit);
return builder.getUnitAttr();
case Token::kw_true:
consumeToken(Token::kw_true);
return builder.getBoolAttr(true);
case Token::kw_false:
consumeToken(Token::kw_false);
return builder.getBoolAttr(false);
case Token::floatliteral: {
auto val = getToken().getFloatingPointValue();
if (!val.hasValue())
return (emitError("floating point value too large for attribute"),
nullptr);
auto valTok = getToken().getLoc();
consumeToken(Token::floatliteral);
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to F64 when no type is specified.
type = builder.getF64Type();
}
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for specified type"),
nullptr);
return FloatAttr::getChecked(type, val.getValue(),
getEncodedSourceLocation(valTok));
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer constant out of range for attribute"),
nullptr);
consumeToken(Token::integer);
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to i64 if not type is specified.
type = builder.getIntegerType(64);
}
}
if (!type.isIntOrIndex())
return (emitError("integer value not valid for specified type"), nullptr);
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, val.getValue());
if (apInt != *val)
return emitError("integer constant out of range for attribute"), nullptr;
return builder.getIntegerAttr(type, apInt);
}
case Token::minus: {
consumeToken(Token::minus);
if (getToken().is(Token::integer)) {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer constant out of range for attribute"),
nullptr);
consumeToken(Token::integer);
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to i64 if not type is specified.
type = builder.getIntegerType(64);
}
}
if (!type.isIntOrIndex())
return (emitError("integer value not valid for type"), nullptr);
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, *val, /*isSigned=*/true);
if (apInt != *val)
return (emitError("integer constant out of range for attribute"),
nullptr);
return builder.getIntegerAttr(type, -apInt);
}
if (getToken().is(Token::floatliteral)) {
auto val = getToken().getFloatingPointValue();
if (!val.hasValue())
return (emitError("floating point value too large for attribute"),
nullptr);
auto valTok = getToken().getLoc();
consumeToken(Token::floatliteral);
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to F64 when no type is specified.
type = builder.getF64Type();
}
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for type"), nullptr);
return FloatAttr::getChecked(type, -val.getValue(),
getEncodedSourceLocation(valTok));
}
return (emitError("expected constant integer or floating point value"),
nullptr);
}
case Token::string: {
auto val = getToken().getStringValue();
consumeToken(Token::string);
return builder.getStringAttr(val);
}
case Token::l_brace: {
SmallVector<NamedAttribute, 4> elements;
if (parseAttributeDict(elements))
return nullptr;
return builder.getDictionaryAttr(elements);
}
case Token::l_square: {
consumeToken(Token::l_square);
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
// Parse an AffineMap or IntegerSet attribute.
case Token::l_paren: {
// Try to parse an affine map or an integer set reference.
AffineMap map;
@ -1076,170 +956,98 @@ Attribute Parser::parseAttribute(Type type) {
return builder.getIntegerSetAttr(set);
}
// Parse an array attribute.
case Token::l_square: {
consumeToken(Token::l_square);
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
// Parse a boolean attribute.
case Token::kw_false:
consumeToken(Token::kw_false);
return builder.getBoolAttr(false);
case Token::kw_true:
consumeToken(Token::kw_true);
return builder.getBoolAttr(true);
// Parse a dense elements attribute.
case Token::kw_dense:
return parseDenseElementsAttr();
// Parse a dictionary attribute.
case Token::l_brace: {
SmallVector<NamedAttribute, 4> elements;
if (parseAttributeDict(elements))
return nullptr;
return builder.getDictionaryAttr(elements);
}
// Parse an extended attribute, i.e. alias or dialect attribute.
case Token::hash_identifier:
return parseExtendedAttr(type);
// Parse floating point and integer attributes.
case Token::floatliteral:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseIntegerAttr(type, /*isSigned=*/false);
case Token::minus: {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseIntegerAttr(type, /*isSigned=*/true);
if (getToken().is(Token::floatliteral))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitError("expected constant integer or floating point value"),
nullptr);
}
// Parse a function attribute.
case Token::at_identifier: {
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
return builder.getFunctionAttr(nameStr.drop_front());
}
case Token::kw_opaque: {
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
if (getToken().getKind() != Token::string)
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
auto *dialect = builder.getContext()->getRegisteredDialect(name);
// TODO(shpeisman): Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
if (!dialect)
return (emitError("no registered dialect with namespace '" + name + "'"),
nullptr);
// Parse an opaque elements attribute.
case Token::kw_opaque:
return parseOpaqueElementsAttr();
consumeToken(Token::string);
if (parseToken(Token::comma, "expected ','"))
return nullptr;
// Parse a sparse elements attribute.
case Token::kw_sparse:
return parseSparseElementsAttr();
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
// Parse a splat elements attribute.
case Token::kw_splat:
return parseSplatElementsAttr();
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
if (getToken().getKind() != Token::string)
return (emitError("opaque string should start with '0x'"), nullptr);
// Parse a string attribute.
case Token::string: {
auto val = getToken().getStringValue();
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
return (emitError("opaque string should start with '0x'"), nullptr);
val = val.substr(2);
if (!std::all_of(val.begin(), val.end(),
[](char c) { return llvm::isHexDigit(c); })) {
return (emitError("opaque string only contains hex digits"), nullptr);
}
consumeToken(Token::string);
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
return builder.getStringAttr(val);
}
case Token::kw_splat: {
consumeToken(Token::kw_splat);
if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
switch (getToken().getKind()) {
case Token::floatliteral:
case Token::integer:
case Token::kw_false:
case Token::kw_true:
case Token::minus: {
auto scalar = parseAttribute(type.getElementType());
if (!scalar)
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getSplatElementsAttr(type, scalar);
}
default:
return (emitError("expected scalar constant inside tensor literal"),
nullptr);
}
}
case Token::kw_dense: {
consumeToken(Token::kw_dense);
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
// Parse a 'unit' attribute.
case Token::kw_unit:
consumeToken(Token::kw_unit);
return builder.getUnitAttr();
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
auto attr = parseDenseElementsAttr(type);
if (!attr)
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return attr;
}
case Token::kw_sparse: {
consumeToken(Token::kw_sparse);
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
switch (getToken().getKind()) {
case Token::l_square: {
/// Parse indices
auto indicesEltType = builder.getIntegerType(64);
auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
if (!indices)
return nullptr;
if (parseToken(Token::comma, "expected ','"))
return nullptr;
/// Parse values.
auto valuesEltType = type.getElementType();
auto values = parseDenseElementsAttrAsTensor(valuesEltType);
if (!values)
return nullptr;
/// Sanity check.
auto valuesType = values.getType();
if (valuesType.getRank() != 1) {
return (emitError("expected 1-d tensor for values"), nullptr);
}
auto indicesType = indices.getType();
auto sameShape = (indicesType.getRank() == 1) ||
(type.getRank() == indicesType.getDimSize(1));
auto sameElementNum =
indicesType.getDimSize(0) == valuesType.getDimSize(0);
if (!sameShape || !sameElementNum) {
emitError() << "expected shape ([" << type.getShape()
<< "]); inferred shape of indices literal (["
<< indicesType.getShape()
<< "]); inferred shape of values literal (["
<< valuesType.getShape() << "])";
return nullptr;
}
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
// Build the sparse elements attribute by the indices and values.
return builder.getSparseElementsAttr(
type, indices.cast<DenseIntElementsAttr>(), values);
}
default:
return (emitError("expected '[' to start sparse tensor literal"),
nullptr);
}
return (emitError("expected elements literal has a tensor or vector type"),
nullptr);
}
default: {
default:
// Parse a type attribute.
if (Type type = parseType())
return builder.getTypeAttr(type);
return nullptr;
}
}
}
/// Attribute dictionary.
@ -1289,7 +1097,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
/// attribute-alias ::= `#` alias-name
///
Attribute Parser::parseExtendedAttribute(Type type) {
Attribute Parser::parseExtendedAttr(Type type) {
Attribute attr = parseExtendedSymbol<Attribute>(
*this, Token::hash_identifier, state.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
@ -1313,6 +1121,183 @@ Attribute Parser::parseExtendedAttribute(Type type) {
return attr;
}
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
auto val = getToken().getFloatingPointValue();
if (!val.hasValue())
return (emitError("floating point value too large for attribute"), nullptr);
auto valTok = getToken().getLoc();
consumeToken(Token::floatliteral);
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
type = builder.getF64Type();
else if (!(type = parseType()))
return nullptr;
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for specified type"),
nullptr);
return FloatAttr::getChecked(type,
isNegative ? -val.getValue() : val.getValue(),
getEncodedSourceLocation(valTok));
}
/// Parse an integer attribute.
Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() ||
(isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
return (emitError("integer constant out of range for attribute"), nullptr);
consumeToken(Token::integer);
if (!type) {
// Default to i64 if not type is specified.
if (!consumeIf(Token::colon))
type = builder.getIntegerType(64);
else if (!(type = parseType()))
return nullptr;
}
if (!type.isIntOrIndex())
return (emitError("integer value not valid for specified type"), nullptr);
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
APInt apInt(width, *val, isSigned);
if (apInt != *val)
return (emitError("integer constant out of range for attribute"), nullptr);
return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
}
/// Parse an opaque elements attribute.
Attribute Parser::parseOpaqueElementsAttr() {
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
if (getToken().isNot(Token::string))
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
auto *dialect = builder.getContext()->getRegisteredDialect(name);
// TODO(shpeisman): Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
if (!dialect)
return (emitError("no registered dialect with namespace '" + name + "'"),
nullptr);
consumeToken(Token::string);
if (parseToken(Token::comma, "expected ','"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
if (getToken().getKind() != Token::string)
return (emitError("opaque string should start with '0x'"), nullptr);
auto val = getToken().getStringValue();
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
return (emitError("opaque string should start with '0x'"), nullptr);
val = val.substr(2);
if (!llvm::all_of(val, llvm::isHexDigit))
return (emitError("opaque string only contains hex digits"), nullptr);
consumeToken(Token::string);
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
}
/// Parse a sparse elements attribute.
Attribute Parser::parseSparseElementsAttr() {
consumeToken(Token::kw_sparse);
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
if (getToken().isNot(Token::l_square))
return emitError("expected '[' to start sparse tensor literal"), nullptr;
/// Parse indices
auto indicesEltType = builder.getIntegerType(64);
auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
if (!indices)
return nullptr;
if (parseToken(Token::comma, "expected ','"))
return nullptr;
/// Parse values.
auto valuesEltType = type.getElementType();
auto values = parseDenseElementsAttrAsTensor(valuesEltType);
if (!values)
return nullptr;
/// Sanity check.
auto valuesType = values.getType();
if (valuesType.getRank() != 1)
return (emitError("expected 1-d tensor for values"), nullptr);
auto indicesType = indices.getType();
auto sameShape = (indicesType.getRank() == 1) ||
(type.getRank() == indicesType.getDimSize(1));
auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
if (!sameShape || !sameElementNum) {
emitError() << "expected shape ([" << type.getShape()
<< "]); inferred shape of indices literal (["
<< indicesType.getShape()
<< "]); inferred shape of values literal (["
<< valuesType.getShape() << "])";
return nullptr;
}
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
// Build the sparse elements attribute by the indices and values.
return builder.getSparseElementsAttr(
type, indices.cast<DenseIntElementsAttr>(), values);
}
/// Parse a splat elements attribute.
Attribute Parser::parseSplatElementsAttr() {
consumeToken(Token::kw_splat);
if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
switch (getToken().getKind()) {
case Token::floatliteral:
case Token::integer:
case Token::kw_false:
case Token::kw_true:
case Token::minus: {
auto scalar = parseAttribute(type.getElementType());
if (!scalar)
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getSplatElementsAttr(type, scalar);
}
default:
return emitError("expected scalar constant inside tensor literal"), nullptr;
}
}
namespace {
class TensorLiteralParser {
public:
@ -1457,9 +1442,19 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
auto eltTy = type.getElementType();
TensorLiteralParser literalParser(*this, eltTy);
Attribute Parser::parseDenseElementsAttr() {
consumeToken(Token::kw_dense);
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
auto type = parseElementsLiteralType();
if (!type)
return nullptr;
if (parseToken(Token::comma, "expected ',' after elements literal type"))
return nullptr;
TensorLiteralParser literalParser(*this, type.getElementType());
if (literalParser.parse())
return nullptr;
@ -1470,6 +1465,9 @@ DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
return nullptr;
}
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getDenseElementsAttr(type, literalParser.getValues())
.cast<DenseElementsAttr>();
}
@ -1504,9 +1502,8 @@ ShapedType Parser::parseElementsLiteralType() {
return nullptr;
if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
return (
emitError("elements literal must be a ranked tensor or vector type"),
nullptr);
emitError("elements literal must be a ranked tensor or vector type");
return nullptr;
}
auto sType = type.cast<ShapedType>();