forked from OSchip/llvm-project
NFC: Split up Parser::parseAttribute into multiple smaller functions to improve readability.
PiperOrigin-RevId: 251158192
This commit is contained in:
parent
c914976c72
commit
b1393c2cd0
|
@ -224,10 +224,25 @@ public:
|
|||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||
|
||||
/// Parse an extended attribute.
|
||||
Attribute parseExtendedAttribute(Type type);
|
||||
Attribute parseExtendedAttr(Type type);
|
||||
|
||||
/// Parse a float attribute.
|
||||
Attribute parseFloatAttr(Type type, bool isNegative);
|
||||
|
||||
/// Parse an integer attribute.
|
||||
Attribute parseIntegerAttr(Type type, bool isSigned);
|
||||
|
||||
/// Parse an opaque elements attribute.
|
||||
Attribute parseOpaqueElementsAttr();
|
||||
|
||||
/// Parse a sparse elements attribute.
|
||||
Attribute parseSparseElementsAttr();
|
||||
|
||||
/// Parse a splat elements attribute.
|
||||
Attribute parseSplatElementsAttr();
|
||||
|
||||
/// Parse a dense elements attribute.
|
||||
DenseElementsAttr parseDenseElementsAttr(ShapedType type);
|
||||
Attribute parseDenseElementsAttr();
|
||||
DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
|
||||
ShapedType parseElementsLiteralType();
|
||||
|
||||
|
@ -928,142 +943,7 @@ ParseResult Parser::parseXInDimensionList() {
|
|||
///
|
||||
Attribute Parser::parseAttribute(Type type) {
|
||||
switch (getToken().getKind()) {
|
||||
case Token::hash_identifier:
|
||||
return parseExtendedAttribute(type);
|
||||
case Token::kw_unit:
|
||||
consumeToken(Token::kw_unit);
|
||||
return builder.getUnitAttr();
|
||||
|
||||
case Token::kw_true:
|
||||
consumeToken(Token::kw_true);
|
||||
return builder.getBoolAttr(true);
|
||||
case Token::kw_false:
|
||||
consumeToken(Token::kw_false);
|
||||
return builder.getBoolAttr(false);
|
||||
|
||||
case Token::floatliteral: {
|
||||
auto val = getToken().getFloatingPointValue();
|
||||
if (!val.hasValue())
|
||||
return (emitError("floating point value too large for attribute"),
|
||||
nullptr);
|
||||
auto valTok = getToken().getLoc();
|
||||
consumeToken(Token::floatliteral);
|
||||
if (!type) {
|
||||
if (consumeIf(Token::colon)) {
|
||||
if (!(type = parseType()))
|
||||
return nullptr;
|
||||
} else {
|
||||
// Default to F64 when no type is specified.
|
||||
type = builder.getF64Type();
|
||||
}
|
||||
}
|
||||
if (!type.isa<FloatType>())
|
||||
return (emitError("floating point value not valid for specified type"),
|
||||
nullptr);
|
||||
return FloatAttr::getChecked(type, val.getValue(),
|
||||
getEncodedSourceLocation(valTok));
|
||||
}
|
||||
case Token::integer: {
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
if (!val.hasValue() || (int64_t)val.getValue() < 0)
|
||||
return (emitError("integer constant out of range for attribute"),
|
||||
nullptr);
|
||||
consumeToken(Token::integer);
|
||||
if (!type) {
|
||||
if (consumeIf(Token::colon)) {
|
||||
if (!(type = parseType()))
|
||||
return nullptr;
|
||||
} else {
|
||||
// Default to i64 if not type is specified.
|
||||
type = builder.getIntegerType(64);
|
||||
}
|
||||
}
|
||||
if (!type.isIntOrIndex())
|
||||
return (emitError("integer value not valid for specified type"), nullptr);
|
||||
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
|
||||
APInt apInt(width, val.getValue());
|
||||
if (apInt != *val)
|
||||
return emitError("integer constant out of range for attribute"), nullptr;
|
||||
return builder.getIntegerAttr(type, apInt);
|
||||
}
|
||||
|
||||
case Token::minus: {
|
||||
consumeToken(Token::minus);
|
||||
if (getToken().is(Token::integer)) {
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
|
||||
return (emitError("integer constant out of range for attribute"),
|
||||
nullptr);
|
||||
consumeToken(Token::integer);
|
||||
if (!type) {
|
||||
if (consumeIf(Token::colon)) {
|
||||
if (!(type = parseType()))
|
||||
return nullptr;
|
||||
} else {
|
||||
// Default to i64 if not type is specified.
|
||||
type = builder.getIntegerType(64);
|
||||
}
|
||||
}
|
||||
if (!type.isIntOrIndex())
|
||||
return (emitError("integer value not valid for type"), nullptr);
|
||||
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
|
||||
APInt apInt(width, *val, /*isSigned=*/true);
|
||||
if (apInt != *val)
|
||||
return (emitError("integer constant out of range for attribute"),
|
||||
nullptr);
|
||||
return builder.getIntegerAttr(type, -apInt);
|
||||
}
|
||||
if (getToken().is(Token::floatliteral)) {
|
||||
auto val = getToken().getFloatingPointValue();
|
||||
if (!val.hasValue())
|
||||
return (emitError("floating point value too large for attribute"),
|
||||
nullptr);
|
||||
auto valTok = getToken().getLoc();
|
||||
consumeToken(Token::floatliteral);
|
||||
if (!type) {
|
||||
if (consumeIf(Token::colon)) {
|
||||
if (!(type = parseType()))
|
||||
return nullptr;
|
||||
} else {
|
||||
// Default to F64 when no type is specified.
|
||||
type = builder.getF64Type();
|
||||
}
|
||||
}
|
||||
if (!type.isa<FloatType>())
|
||||
return (emitError("floating point value not valid for type"), nullptr);
|
||||
return FloatAttr::getChecked(type, -val.getValue(),
|
||||
getEncodedSourceLocation(valTok));
|
||||
}
|
||||
|
||||
return (emitError("expected constant integer or floating point value"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
case Token::string: {
|
||||
auto val = getToken().getStringValue();
|
||||
consumeToken(Token::string);
|
||||
return builder.getStringAttr(val);
|
||||
}
|
||||
|
||||
case Token::l_brace: {
|
||||
SmallVector<NamedAttribute, 4> elements;
|
||||
if (parseAttributeDict(elements))
|
||||
return nullptr;
|
||||
return builder.getDictionaryAttr(elements);
|
||||
}
|
||||
case Token::l_square: {
|
||||
consumeToken(Token::l_square);
|
||||
SmallVector<Attribute, 4> elements;
|
||||
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
elements.push_back(parseAttribute());
|
||||
return elements.back() ? success() : failure();
|
||||
};
|
||||
|
||||
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
|
||||
return nullptr;
|
||||
return builder.getArrayAttr(elements);
|
||||
}
|
||||
// Parse an AffineMap or IntegerSet attribute.
|
||||
case Token::l_paren: {
|
||||
// Try to parse an affine map or an integer set reference.
|
||||
AffineMap map;
|
||||
|
@ -1076,170 +956,98 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
return builder.getIntegerSetAttr(set);
|
||||
}
|
||||
|
||||
// Parse an array attribute.
|
||||
case Token::l_square: {
|
||||
consumeToken(Token::l_square);
|
||||
|
||||
SmallVector<Attribute, 4> elements;
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
elements.push_back(parseAttribute());
|
||||
return elements.back() ? success() : failure();
|
||||
};
|
||||
|
||||
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
|
||||
return nullptr;
|
||||
return builder.getArrayAttr(elements);
|
||||
}
|
||||
|
||||
// Parse a boolean attribute.
|
||||
case Token::kw_false:
|
||||
consumeToken(Token::kw_false);
|
||||
return builder.getBoolAttr(false);
|
||||
case Token::kw_true:
|
||||
consumeToken(Token::kw_true);
|
||||
return builder.getBoolAttr(true);
|
||||
|
||||
// Parse a dense elements attribute.
|
||||
case Token::kw_dense:
|
||||
return parseDenseElementsAttr();
|
||||
|
||||
// Parse a dictionary attribute.
|
||||
case Token::l_brace: {
|
||||
SmallVector<NamedAttribute, 4> elements;
|
||||
if (parseAttributeDict(elements))
|
||||
return nullptr;
|
||||
return builder.getDictionaryAttr(elements);
|
||||
}
|
||||
|
||||
// Parse an extended attribute, i.e. alias or dialect attribute.
|
||||
case Token::hash_identifier:
|
||||
return parseExtendedAttr(type);
|
||||
|
||||
// Parse floating point and integer attributes.
|
||||
case Token::floatliteral:
|
||||
return parseFloatAttr(type, /*isNegative=*/false);
|
||||
case Token::integer:
|
||||
return parseIntegerAttr(type, /*isSigned=*/false);
|
||||
case Token::minus: {
|
||||
consumeToken(Token::minus);
|
||||
if (getToken().is(Token::integer))
|
||||
return parseIntegerAttr(type, /*isSigned=*/true);
|
||||
if (getToken().is(Token::floatliteral))
|
||||
return parseFloatAttr(type, /*isNegative=*/true);
|
||||
|
||||
return (emitError("expected constant integer or floating point value"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Parse a function attribute.
|
||||
case Token::at_identifier: {
|
||||
auto nameStr = getTokenSpelling();
|
||||
consumeToken(Token::at_identifier);
|
||||
return builder.getFunctionAttr(nameStr.drop_front());
|
||||
}
|
||||
case Token::kw_opaque: {
|
||||
consumeToken(Token::kw_opaque);
|
||||
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
||||
return nullptr;
|
||||
|
||||
if (getToken().getKind() != Token::string)
|
||||
return (emitError("expected dialect namespace"), nullptr);
|
||||
auto name = getToken().getStringValue();
|
||||
auto *dialect = builder.getContext()->getRegisteredDialect(name);
|
||||
// TODO(shpeisman): Allow for having an unknown dialect on an opaque
|
||||
// attribute. Otherwise, it can't be roundtripped without having the dialect
|
||||
// registered.
|
||||
if (!dialect)
|
||||
return (emitError("no registered dialect with namespace '" + name + "'"),
|
||||
nullptr);
|
||||
// Parse an opaque elements attribute.
|
||||
case Token::kw_opaque:
|
||||
return parseOpaqueElementsAttr();
|
||||
|
||||
consumeToken(Token::string);
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
// Parse a sparse elements attribute.
|
||||
case Token::kw_sparse:
|
||||
return parseSparseElementsAttr();
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
// Parse a splat elements attribute.
|
||||
case Token::kw_splat:
|
||||
return parseSplatElementsAttr();
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
|
||||
if (getToken().getKind() != Token::string)
|
||||
return (emitError("opaque string should start with '0x'"), nullptr);
|
||||
// Parse a string attribute.
|
||||
case Token::string: {
|
||||
auto val = getToken().getStringValue();
|
||||
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
|
||||
return (emitError("opaque string should start with '0x'"), nullptr);
|
||||
val = val.substr(2);
|
||||
if (!std::all_of(val.begin(), val.end(),
|
||||
[](char c) { return llvm::isHexDigit(c); })) {
|
||||
return (emitError("opaque string only contains hex digits"), nullptr);
|
||||
}
|
||||
consumeToken(Token::string);
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
|
||||
return builder.getStringAttr(val);
|
||||
}
|
||||
case Token::kw_splat: {
|
||||
consumeToken(Token::kw_splat);
|
||||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
switch (getToken().getKind()) {
|
||||
case Token::floatliteral:
|
||||
case Token::integer:
|
||||
case Token::kw_false:
|
||||
case Token::kw_true:
|
||||
case Token::minus: {
|
||||
auto scalar = parseAttribute(type.getElementType());
|
||||
if (!scalar)
|
||||
return nullptr;
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getSplatElementsAttr(type, scalar);
|
||||
}
|
||||
default:
|
||||
return (emitError("expected scalar constant inside tensor literal"),
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
case Token::kw_dense: {
|
||||
consumeToken(Token::kw_dense);
|
||||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
// Parse a 'unit' attribute.
|
||||
case Token::kw_unit:
|
||||
consumeToken(Token::kw_unit);
|
||||
return builder.getUnitAttr();
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
|
||||
auto attr = parseDenseElementsAttr(type);
|
||||
if (!attr)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
|
||||
return attr;
|
||||
}
|
||||
case Token::kw_sparse: {
|
||||
consumeToken(Token::kw_sparse);
|
||||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
|
||||
switch (getToken().getKind()) {
|
||||
case Token::l_square: {
|
||||
/// Parse indices
|
||||
auto indicesEltType = builder.getIntegerType(64);
|
||||
auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
|
||||
if (!indices)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
/// Parse values.
|
||||
auto valuesEltType = type.getElementType();
|
||||
auto values = parseDenseElementsAttrAsTensor(valuesEltType);
|
||||
if (!values)
|
||||
return nullptr;
|
||||
|
||||
/// Sanity check.
|
||||
auto valuesType = values.getType();
|
||||
if (valuesType.getRank() != 1) {
|
||||
return (emitError("expected 1-d tensor for values"), nullptr);
|
||||
}
|
||||
auto indicesType = indices.getType();
|
||||
auto sameShape = (indicesType.getRank() == 1) ||
|
||||
(type.getRank() == indicesType.getDimSize(1));
|
||||
auto sameElementNum =
|
||||
indicesType.getDimSize(0) == valuesType.getDimSize(0);
|
||||
if (!sameShape || !sameElementNum) {
|
||||
emitError() << "expected shape ([" << type.getShape()
|
||||
<< "]); inferred shape of indices literal (["
|
||||
<< indicesType.getShape()
|
||||
<< "]); inferred shape of values literal (["
|
||||
<< valuesType.getShape() << "])";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
|
||||
// Build the sparse elements attribute by the indices and values.
|
||||
return builder.getSparseElementsAttr(
|
||||
type, indices.cast<DenseIntElementsAttr>(), values);
|
||||
}
|
||||
default:
|
||||
return (emitError("expected '[' to start sparse tensor literal"),
|
||||
nullptr);
|
||||
}
|
||||
return (emitError("expected elements literal has a tensor or vector type"),
|
||||
nullptr);
|
||||
}
|
||||
default: {
|
||||
default:
|
||||
// Parse a type attribute.
|
||||
if (Type type = parseType())
|
||||
return builder.getTypeAttr(type);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attribute dictionary.
|
||||
|
@ -1289,7 +1097,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
|
|||
/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
|
||||
/// attribute-alias ::= `#` alias-name
|
||||
///
|
||||
Attribute Parser::parseExtendedAttribute(Type type) {
|
||||
Attribute Parser::parseExtendedAttr(Type type) {
|
||||
Attribute attr = parseExtendedSymbol<Attribute>(
|
||||
*this, Token::hash_identifier, state.attributeAliasDefinitions,
|
||||
[&](StringRef dialectName, StringRef symbolData,
|
||||
|
@ -1313,6 +1121,183 @@ Attribute Parser::parseExtendedAttribute(Type type) {
|
|||
return attr;
|
||||
}
|
||||
|
||||
/// Parse a float attribute.
|
||||
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
|
||||
auto val = getToken().getFloatingPointValue();
|
||||
if (!val.hasValue())
|
||||
return (emitError("floating point value too large for attribute"), nullptr);
|
||||
auto valTok = getToken().getLoc();
|
||||
consumeToken(Token::floatliteral);
|
||||
if (!type) {
|
||||
// Default to F64 when no type is specified.
|
||||
if (!consumeIf(Token::colon))
|
||||
type = builder.getF64Type();
|
||||
else if (!(type = parseType()))
|
||||
return nullptr;
|
||||
}
|
||||
if (!type.isa<FloatType>())
|
||||
return (emitError("floating point value not valid for specified type"),
|
||||
nullptr);
|
||||
return FloatAttr::getChecked(type,
|
||||
isNegative ? -val.getValue() : val.getValue(),
|
||||
getEncodedSourceLocation(valTok));
|
||||
}
|
||||
|
||||
/// Parse an integer attribute.
|
||||
Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
if (!val.hasValue() ||
|
||||
(isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
|
||||
return (emitError("integer constant out of range for attribute"), nullptr);
|
||||
consumeToken(Token::integer);
|
||||
if (!type) {
|
||||
// Default to i64 if not type is specified.
|
||||
if (!consumeIf(Token::colon))
|
||||
type = builder.getIntegerType(64);
|
||||
else if (!(type = parseType()))
|
||||
return nullptr;
|
||||
}
|
||||
if (!type.isIntOrIndex())
|
||||
return (emitError("integer value not valid for specified type"), nullptr);
|
||||
|
||||
int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
|
||||
APInt apInt(width, *val, isSigned);
|
||||
if (apInt != *val)
|
||||
return (emitError("integer constant out of range for attribute"), nullptr);
|
||||
return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
|
||||
}
|
||||
|
||||
/// Parse an opaque elements attribute.
|
||||
Attribute Parser::parseOpaqueElementsAttr() {
|
||||
consumeToken(Token::kw_opaque);
|
||||
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
||||
return nullptr;
|
||||
|
||||
if (getToken().isNot(Token::string))
|
||||
return (emitError("expected dialect namespace"), nullptr);
|
||||
|
||||
auto name = getToken().getStringValue();
|
||||
auto *dialect = builder.getContext()->getRegisteredDialect(name);
|
||||
// TODO(shpeisman): Allow for having an unknown dialect on an opaque
|
||||
// attribute. Otherwise, it can't be roundtripped without having the dialect
|
||||
// registered.
|
||||
if (!dialect)
|
||||
return (emitError("no registered dialect with namespace '" + name + "'"),
|
||||
nullptr);
|
||||
|
||||
consumeToken(Token::string);
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
if (getToken().getKind() != Token::string)
|
||||
return (emitError("opaque string should start with '0x'"), nullptr);
|
||||
|
||||
auto val = getToken().getStringValue();
|
||||
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
|
||||
return (emitError("opaque string should start with '0x'"), nullptr);
|
||||
|
||||
val = val.substr(2);
|
||||
if (!llvm::all_of(val, llvm::isHexDigit))
|
||||
return (emitError("opaque string only contains hex digits"), nullptr);
|
||||
|
||||
consumeToken(Token::string);
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
|
||||
}
|
||||
|
||||
/// Parse a sparse elements attribute.
|
||||
Attribute Parser::parseSparseElementsAttr() {
|
||||
consumeToken(Token::kw_sparse);
|
||||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
if (getToken().isNot(Token::l_square))
|
||||
return emitError("expected '[' to start sparse tensor literal"), nullptr;
|
||||
|
||||
/// Parse indices
|
||||
auto indicesEltType = builder.getIntegerType(64);
|
||||
auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
|
||||
if (!indices)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
/// Parse values.
|
||||
auto valuesEltType = type.getElementType();
|
||||
auto values = parseDenseElementsAttrAsTensor(valuesEltType);
|
||||
if (!values)
|
||||
return nullptr;
|
||||
|
||||
/// Sanity check.
|
||||
auto valuesType = values.getType();
|
||||
if (valuesType.getRank() != 1)
|
||||
return (emitError("expected 1-d tensor for values"), nullptr);
|
||||
|
||||
auto indicesType = indices.getType();
|
||||
auto sameShape = (indicesType.getRank() == 1) ||
|
||||
(type.getRank() == indicesType.getDimSize(1));
|
||||
auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
|
||||
if (!sameShape || !sameElementNum) {
|
||||
emitError() << "expected shape ([" << type.getShape()
|
||||
<< "]); inferred shape of indices literal (["
|
||||
<< indicesType.getShape()
|
||||
<< "]); inferred shape of values literal (["
|
||||
<< valuesType.getShape() << "])";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
|
||||
// Build the sparse elements attribute by the indices and values.
|
||||
return builder.getSparseElementsAttr(
|
||||
type, indices.cast<DenseIntElementsAttr>(), values);
|
||||
}
|
||||
|
||||
/// Parse a splat elements attribute.
|
||||
Attribute Parser::parseSplatElementsAttr() {
|
||||
consumeToken(Token::kw_splat);
|
||||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
|
||||
switch (getToken().getKind()) {
|
||||
case Token::floatliteral:
|
||||
case Token::integer:
|
||||
case Token::kw_false:
|
||||
case Token::kw_true:
|
||||
case Token::minus: {
|
||||
auto scalar = parseAttribute(type.getElementType());
|
||||
if (!scalar)
|
||||
return nullptr;
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getSplatElementsAttr(type, scalar);
|
||||
}
|
||||
default:
|
||||
return emitError("expected scalar constant inside tensor literal"), nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
class TensorLiteralParser {
|
||||
public:
|
||||
|
@ -1457,9 +1442,19 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
|
|||
/// This method compares the shapes from the parsing result and that from the
|
||||
/// input argument. It returns a constructed dense elements attribute if both
|
||||
/// match.
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
|
||||
auto eltTy = type.getElementType();
|
||||
TensorLiteralParser literalParser(*this, eltTy);
|
||||
Attribute Parser::parseDenseElementsAttr() {
|
||||
consumeToken(Token::kw_dense);
|
||||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseElementsLiteralType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::comma, "expected ',' after elements literal type"))
|
||||
return nullptr;
|
||||
|
||||
TensorLiteralParser literalParser(*this, type.getElementType());
|
||||
if (literalParser.parse())
|
||||
return nullptr;
|
||||
|
||||
|
@ -1470,6 +1465,9 @@ DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
|
||||
return builder.getDenseElementsAttr(type, literalParser.getValues())
|
||||
.cast<DenseElementsAttr>();
|
||||
}
|
||||
|
@ -1504,9 +1502,8 @@ ShapedType Parser::parseElementsLiteralType() {
|
|||
return nullptr;
|
||||
|
||||
if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
|
||||
return (
|
||||
emitError("elements literal must be a ranked tensor or vector type"),
|
||||
nullptr);
|
||||
emitError("elements literal must be a ranked tensor or vector type");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto sType = type.cast<ShapedType>();
|
||||
|
|
Loading…
Reference in New Issue