Support hexadecimal floats in tensor literals

Extend the recently introduced support for hexadecimal float literals to tensor
literals, which may also contain special floating point values such as
infinities and NaNs.

Modify TensorLiteralParser to store the list of tokens representing values
until the type is parsed instead of trying to guess the tensor element type
from the token kinds (hexadecimal values can be either integers or floats, and
can be mixed with both).  Maintain the error reports as close as possible to
the existing implementation to avoid disturbing the tests.  They can be
improved in a separate clean-up if deemed necessary.

PiperOrigin-RevId: 260794716
This commit is contained in:
Alex Zinenko 2019-07-30 14:24:30 -07:00 committed by A. Unique TensorFlower
parent 1de519a753
commit 206be96e63
3 changed files with 175 additions and 116 deletions

View File

@ -1152,6 +1152,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
}
/// Construct a float attribute bitwise equivalent to the integer literal.
static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
uint64_t value) {
int width = type.getIntOrFloatBitWidth();
APInt apInt(width, value);
if (apInt != value) {
p->emitError("hexadecimal float constant out of range for type");
return nullptr;
}
APFloat apFloat(type.getFloatSemantics(), apInt);
return p->builder.getFloatAttr(type, apFloat);
}
/// Parse a decimal or a hexadecimal literal, which can be either an integer
/// or a float attribute.
Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
@ -1188,14 +1201,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
// Construct a float attribute bitwise equivalent to the integer literal.
int width = type.getIntOrFloatBitWidth();
APInt apInt(width, *val, isNegative);
if (apInt != *val) {
emitError("hexadecimal float constant out of range for attribute");
return nullptr;
}
APFloat apFloat(floatType.getFloatSemantics(), apInt);
return builder.getFloatAttr(type, apFloat);
return buildHexadecimalFloatLiteral(this, floatType, *val);
}
if (!type.isIntOrIndex())
@ -1306,14 +1312,6 @@ private:
/// parseElement([1]) -> Failure
ParseResult parseElement();
/// Parse an integer element value, returning failure if the value isn't
/// valid.
ParseResult parseIntegerElement(bool isSigned);
/// Parse a floating-point element value, returning failure if the value isn't
/// valid.
ParseResult parseFloatElement(bool isNegative);
/// Parse a list of either lists or elements, returning the dimensions of the
/// parsed sub-tensors in dims. For example:
/// parseList([1, 2, 3]) -> Success, [3]
@ -1327,12 +1325,8 @@ private:
/// The shape inferred from the parsed elements.
SmallVector<int64_t, 4> shape;
/// Storage used when parsing integer elements, this is a pair of <is_signed,
/// value>.
std::vector<std::pair<bool, uint64_t>> intStorage;
/// Storage used when parsing float elements.
std::vector<double> floatStorage;
/// Storage used when parsing elements, this is a pair of <is_negated, token>.
std::vector<std::pair<bool, Token>> storage;
/// A flag that indicates the type of elements that have been parsed.
llvm::Optional<ElementKind> knownEltKind;
@ -1370,21 +1364,43 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
ShapedType type,
IntegerType eltTy) {
// Check to see if floating point values were parsed.
if (!floatStorage.empty()) {
p.emitError() << "expected integer elements, but parsed floating-point";
return nullptr;
std::vector<APInt> intElements;
intElements.reserve(storage.size());
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
// Check to see if floating point values were parsed.
if (token.is(Token::floatliteral)) {
p.emitError() << "expected integer elements, but parsed floating-point";
return nullptr;
}
assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
"unexpected token type");
if (token.isAny(Token::kw_true, Token::kw_false)) {
if (!eltTy.isInteger(1))
p.emitError() << "expected i1 type for 'true' or 'false' values";
APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
/*isSigned=*/false);
intElements.push_back(apInt);
continue;
}
// Create APInt values for each element with the correct bitwidth.
auto val = token.getUInt64IntegerValue();
if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0
: (int64_t)val.getValue() < 0)) {
p.emitError(token.getLoc(),
"integer constant out of range for attribute");
return nullptr;
}
APInt apInt(eltTy.getWidth(), val.getValue(), isNegative);
if (apInt != val.getValue())
return (p.emitError("integer constant out of range for type"), nullptr);
intElements.push_back(isNegative ? -apInt : apInt);
}
// Create APInt values for each element with the correct bitwidth.
std::vector<APInt> intElements;
intElements.reserve(intStorage.size());
for (auto &signAndValue : intStorage) {
APInt apInt(eltTy.getWidth(), signAndValue.second, signAndValue.first);
if (apInt != signAndValue.second)
return (p.emitError("integer constant out of range for type"), nullptr);
intElements.push_back(signAndValue.first ? -apInt : apInt);
}
return DenseElementsAttr::get(type, intElements);
}
@ -1392,109 +1408,73 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
ShapedType type,
FloatType eltTy) {
// Check to see if integer values were parsed.
if (!intStorage.empty()) {
p.emitError() << "expected floating-point elements, but parsed integer";
return nullptr;
std::vector<Attribute> floatValues;
floatValues.reserve(storage.size());
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
if (isNegative) {
p.emitError(token.getLoc())
<< "hexadecimal float literal should not have a leading minus";
return nullptr;
}
auto val = token.getUInt64IntegerValue();
if (!val.hasValue()) {
p.emitError("hexadecimal float constant out of range for attribute");
return nullptr;
}
FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
if (!attr)
return nullptr;
floatValues.push_back(attr);
continue;
}
// Check to see if any decimal integers or booleans were parsed.
if (!token.is(Token::floatliteral)) {
p.emitError() << "expected floating-point elements, but parsed integer";
return nullptr;
}
// Build the float values from tokens.
auto val = token.getFloatingPointValue();
if (!val.hasValue()) {
p.emitError("floating point value too large for attribute");
return nullptr;
}
floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
}
// Build the float values from the raw integer storage.
std::vector<Attribute> floatValues;
floatValues.reserve(floatStorage.size());
for (auto &elt : floatStorage)
floatValues.push_back(FloatAttr::get(eltTy, elt));
return DenseElementsAttr::get(type, floatValues);
}
ParseResult TensorLiteralParser::parseElement() {
auto loc = p.getToken().getLoc();
ElementKind newEltKind;
switch (p.getToken().getKind()) {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
intStorage.emplace_back(false, p.getToken().is(Token::kw_true));
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
p.consumeToken();
newEltKind = ElementKind::Boolean;
break;
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
// Otherwise, check for an integer value.
if (p.getToken().is(Token::integer)) {
if (parseIntegerElement(/*isSigned=*/true))
return failure();
newEltKind = ElementKind::Integer;
// Otherwise, check for a floating point value.
} else if (p.getToken().is(Token::floatliteral)) {
if (parseFloatElement(/*isNegative=*/true))
return failure();
newEltKind = ElementKind::Float;
} else {
if (!p.getToken().isAny(Token::floatliteral, Token::integer))
return p.emitError("expected integer or floating point literal");
}
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
break;
// Parse a floating-point element.
case Token::floatliteral:
if (parseFloatElement(/*isNegative=*/false))
return failure();
newEltKind = ElementKind::Float;
break;
// Parse an integer element.
case Token::integer:
if (parseIntegerElement(/*isSigned=*/false))
return failure();
newEltKind = ElementKind::Integer;
break;
default:
return p.emitError("expected element literal of primitive type");
}
// Check to see if the element kind has changed from the previously inferred
// type.
if (!knownEltKind)
knownEltKind = newEltKind;
else if (knownEltKind != newEltKind)
return p.emitError(loc)
<< "tensor element type differs from previously inferred type, with "
"old type of "
<< getElementKindStr(*knownEltKind) << ", and new type of "
<< getElementKindStr(newEltKind);
return success();
}
/// Parse an integer element value, returning failure if the value isn't
/// valid.
ParseResult TensorLiteralParser::parseIntegerElement(bool isSigned) {
// Check that the integer value is valid.
auto val = p.getToken().getUInt64IntegerValue();
if (!val.hasValue() ||
(isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
return p.emitError("integer constant out of range for attribute");
// Add it to the storage.
p.consumeToken(Token::integer);
intStorage.emplace_back(isSigned, *val);
return success();
}
/// Parse a floating-point element value, returning failure if the value isn't
/// valid.
ParseResult TensorLiteralParser::parseFloatElement(bool isNegative) {
// Check that the float value is valid.
auto val = p.getToken().getFloatingPointValue();
if (!val.hasValue())
return p.emitError("floating point value too large for attribute");
// Add it to the storage.
p.consumeToken(Token::floatliteral);
floatStorage.push_back(isNegative ? -val.getValue() : val.getValue());
return success();
}

View File

@ -1062,7 +1062,7 @@ func @hexadecimal_float_leading_minus() {
// -----
func @hexadecimal_float_literal_overflow() {
// expected-error @+1 {{hexadecimal float constant out of range for attribute}}
// expected-error @+1 {{hexadecimal float constant out of range for type}}
"foo"() {value = 0xffffffff : f16} : () -> ()
}
@ -1073,3 +1073,69 @@ func @decimal_float_literal() {
// expected-note @+1 {{add a trailing dot to make the literal a float}}
"foo"() {value = 42 : f32} : () -> ()
}
// -----
func @float_in_int_tensor() {
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
"foo"() {bar = dense<[42.0, 42]> : tensor<2xi32>} : () -> ()
}
// -----
func @float_in_bool_tensor() {
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
"foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
}
// -----
func @decimal_int_in_float_tensor() {
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
"foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
}
// -----
func @bool_in_float_tensor() {
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
"foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
}
// -----
func @hexadecimal_float_leading_minus_in_tensor() {
// expected-error @+1 {{hexadecimal float literal should not have a leading minus}}
"foo"() {bar = dense<-0x7FFFFFFF> : tensor<2xf32>} : () -> ()
}
// -----
// Check that we report an error when a value could be parsed, but does not fit
// into the specified type.
func @hexadecimal_float_too_wide_for_type_in_tensor() {
// expected-error @+1 {{hexadecimal float constant out of range for type}}
"foo"() {bar = dense<0x7FF0000000000000> : tensor<2xf32>} : () -> ()
}
// -----
// Check that we report an error when a value is too wide to be parsed.
func @hexadecimal_float_too_wide_in_tensor() {
// expected-error @+1 {{hexadecimal float constant out of range for attribute}}
"foo"() {bar = dense<0x7FFFFFF0000000000000> : tensor<2xf32>} : () -> ()
}
// -----
func @integer_too_wide_in_tensor() {
// expected-error @+1 {{integer constant out of range for type}}
"foo"() {bar = dense<0xFFFFFFFFFFFFFF> : tensor<2xi16>} : () -> ()
}
// -----
func @bool_literal_in_non_bool_tensor() {
// expected-error @+1 {{expected i1 type for 'true' or 'false' values}}
"foo"() {bar = dense<true> : tensor<2xi16>} : () -> ()
}

View File

@ -1023,3 +1023,16 @@ func @f32_potential_precision_loss() {
%0 = constant -1.23697901 : f32
return
}
// CHECK-LABEL: @special_float_values_in_tensors
func @special_float_values_in_tensors() {
// CHECK: dense<0xFFFFFFFF> : tensor<4x4xf32>
"foo"(){bar = dense<0xFFFFFFFF> : tensor<4x4xf32>} : () -> ()
// CHECK: dense<[{{\[}}0xFFFFFFFF, 0x7F800000], [0x7FBFFFFF, 0x7F800001]]> : tensor<2x2xf32>
"foo"(){bar = dense<[[0xFFFFFFFF, 0x7F800000], [0x7FBFFFFF, 0x7F800001]]> : tensor<2x2xf32>} : () -> ()
// CHECK: dense<[0xFFFFFFFF, 0.000000e+00]> : tensor<2xf32>
"foo"(){bar = dense<[0xFFFFFFFF, 0.0]> : tensor<2xf32>} : () -> ()
// CHECK: sparse<[{{\[}}1, 1, 0], [0, 1, 1]], [0xFFFFFFFF, 0x7F800001]>
"foo"(){bar = sparse<[[1,1,0],[0,1,1]], [0xFFFFFFFF, 0x7F800001]> : tensor<2x2x2xf32>} : () -> ()
}