Fix builder getFloatAttr of double to use F64 type and use fltSemantics in FloatAttr.

Store FloatAttr using more appropriate fltSemantics (mostly fixing up F32/F64 storage, F16/BF16 pending). Previously F32 type was used incorrectly for double (the storage was double). Also add query method that returns fltSemantics for IEEE fp types and use that to verify that the APfloat given matches the type:
* FloatAttr created using APFloat is verified that the semantics of the type and APFloat matches;
* FloatAttr created using double has the APFloat created to match the semantics of the type;

Change parsing of tensor negative splat element to pass in the element type expected. Misc other changes to account for the storage type matching the attribute.

PiperOrigin-RevId: 225821834
This commit is contained in:
Jacques Pienaar 2018-12-17 07:19:53 -08:00 committed by jpienaar
parent 72159f5ede
commit 49c4d2a630
7 changed files with 99 additions and 28 deletions

View File

@ -106,6 +106,7 @@ public:
// TODO(jpienaar): remove these.
IntegerAttr getIntegerAttr(int64_t value);
FloatAttr getFloatAttr(double value);
FloatAttr getFloatAttr(float value);
StringAttr getStringAttr(StringRef bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map);

View File

@ -392,7 +392,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
"[-+]?[0-9] regex does not match!");
// Reparse stringized version!
if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
os << strValue;
return;
}

View File

@ -136,11 +136,15 @@ IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
}
FloatAttr Builder::getFloatAttr(double value) {
return FloatAttr::get(getF64Type(), APFloat(value));
}
FloatAttr Builder::getFloatAttr(float value) {
return FloatAttr::get(getF32Type(), APFloat(value));
}
FloatAttr Builder::getFloatAttr(Type type, double value) {
return FloatAttr::get(type, APFloat(value));
return FloatAttr::get(type, value);
}
FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {

View File

@ -1171,11 +1171,56 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
return get(type, APInt(width, value));
}
/// Returns the floating semantics for the given type.
static const fltSemantics &getFloatSemantics(Type type) {
if (type.isBF16())
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
// not defined in LLVM.
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
// else one could add it.
// static const fltSemantics semBF16 = {127, -126, 8, 16};
return APFloat::IEEEdouble();
if (type.isF16())
// Treat F16 as double. This avoids needing to change the tensor element
// parsing for now. TODO: Fix this to use the correct semantics instead.
return APFloat::IEEEdouble();
if (type.isF32())
return APFloat::IEEEsingle();
if (type.isF64())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
FloatAttr FloatAttr::get(Type type, double value) {
return get(type, APFloat(value));
Optional<APFloat> val;
if (type.isBF16() || type.isF16())
// Treat BF16 and F16 as double. This avoids needing to change the tensor
// element parsing for now. TODO: Fix this to use the correct semantics
// instead.
val = APFloat(value);
else if (type.isF32())
val = APFloat(static_cast<float>(value));
else if (type.isF64())
val = APFloat(value);
else {
bool unused;
val = APFloat(value);
auto status =
(*val).convert(getFloatSemantics(type), APFloat::rmTowardZero, &unused);
if (status != APFloat::opOK) {
auto context = type.getContext();
context->emitError(
UnknownLoc::get(context),
"failed to convert floating point value to requested type");
val.reset();
}
}
return get(type, *val);
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
assert(&getFloatSemantics(type) == &value.getSemantics() &&
"FloatAttr type doesn't match the type implied by its value");
auto &impl = type.getContext()->getImpl();
// Look to see if the float attribute has been created already.
@ -1393,6 +1438,8 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
// TODO(fengliuai): Add verification that the Attribute matches the element
// type.
auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.

View File

@ -678,17 +678,25 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto result = p.parseAttribute();
auto result = p.parseAttribute(eltTy);
if (!result)
return p.emitError("expected tensor element");
return ParseResult::ParseFailure;
// check result matches the element type.
switch (eltTy.getKind()) {
case Type::Kind::F32: {
if (!result.isa<FloatAttr>())
return p.emitError(
"expected tensor literal element has floating point type");
double value = result.cast<FloatAttr>().getValue().convertToFloat();
addToStorage(*(uint64_t *)(&value));
break;
}
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
if (!result.isa<FloatAttr>())
return p.emitError("expected tensor literal element has float type");
return p.emitError(
"expected tensor literal element has floating point type");
double value = result.cast<FloatAttr>().getDouble();
addToStorage(*(uint64_t *)(&value));
break;
@ -815,19 +823,20 @@ Attribute Parser::parseAttribute(Type type) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to F32 when no type is specified.
type = builder.getF32Type();
// Default to F64 when no type is specified.
type = builder.getF64Type();
}
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for specified type"),
nullptr);
return builder.getFloatAttr(type, APFloat(val.getValue()));
return builder.getFloatAttr(type, val.getValue());
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
return (emitError("integer constant out of range for attribute"),
nullptr);
consumeToken(Token::integer);
if (!type) {
if (consumeIf(Token::colon)) {
@ -841,7 +850,10 @@ Attribute Parser::parseAttribute(Type type) {
if (!type.isIntOrIndex())
return (emitError("integer value not valid for specified type"), nullptr);
int width = type.isIndex() ? 64 : type.getBitWidth();
return builder.getIntegerAttr(type, APInt(width, val.getValue()));
APInt apInt(width, val.getValue());
if (apInt != *val)
return emitError("integer constant out of range for attribute"), nullptr;
return builder.getIntegerAttr(type, apInt);
}
case Token::minus: {
@ -849,7 +861,8 @@ Attribute Parser::parseAttribute(Type type) {
if (getToken().is(Token::integer)) {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
return (emitError("integer constant out of range for attribute"),
nullptr);
consumeToken(Token::integer);
if (!type) {
if (consumeIf(Token::colon)) {
@ -863,7 +876,11 @@ Attribute Parser::parseAttribute(Type type) {
if (!type.isIntOrIndex())
return (emitError("integer value not valid for type"), nullptr);
int width = type.isIndex() ? 64 : type.getBitWidth();
return builder.getIntegerAttr(type, -APInt(width, val.getValue()));
APInt apInt(width, *val, /*isSigned=*/true);
if (apInt != *val)
return (emitError("integer constant out of range for attribute"),
nullptr);
return builder.getIntegerAttr(type, -apInt);
}
if (getToken().is(Token::floatliteral)) {
auto val = getToken().getFloatingPointValue();
@ -882,7 +899,7 @@ Attribute Parser::parseAttribute(Type type) {
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for type"), nullptr);
return builder.getFloatAttr(type, APFloat(-val.getValue()));
return builder.getFloatAttr(type, -val.getValue());
}
return (emitError("expected constant integer or floating point value"),

View File

@ -702,28 +702,30 @@ bb0:
cfgfunc @elementsattr_floattype1() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi32>, [4.0]>} : () -> () // expected-error {{expected tensor literal element has integer type}}
// expected-error@+1 {{floating point value not valid for specified type}}
"foo"(){bar: dense<tensor<1xi32>, [4.0]>} : () -> ()
}
// -----
cfgfunc @elementsattr_floattype2() -> () {
bb0:
"foo"(){bar: dense<tensor<1xf32>, [4]>} : () -> () // expected-error {{expected tensor literal element has float type}}
// expected-error@+1 {{integer value not valid for specified type}}
"foo"(){bar: dense<tensor<1xf32>, [4]>} : () -> ()
}
// -----
cfgfunc @elementsattr_toolarge1() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi8>, [777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
"foo"(){bar: dense<tensor<1xi8>, [777]>} : () -> () // expected-error {{integer constant out of range for attribute}}
}
// -----
cfgfunc @elementsattr_toolarge2() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi8>, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
"foo"(){bar: dense<tensor<1xi8>, [-777]>} : () -> () // expected-error {{integer constant out of range for attribute}}
}
// -----

View File

@ -518,18 +518,18 @@ mlfunc @mlfuncattrempty() -> ()
#map_non_simple2 = ()[s0, s1] -> (s0 + s1)
#map_non_simple3 = ()[s0] -> (s0 + 3)
mlfunc @mlfuncsimplemap(%arg0 : index, %arg1 : index) -> () {
for %i0 = 0 to #map_simple0()[] {
for %i0 = 0 to #map_simple0()[] {
// CHECK: for %i0 = 0 to 10 {
for %i1 = 0 to #map_simple1()[%arg1] {
for %i1 = 0 to #map_simple1()[%arg1] {
// CHECK: for %i1 = 0 to %arg1 {
for %i2 = 0 to #map_non_simple0(%i0)[] {
for %i2 = 0 to #map_non_simple0(%i0)[] {
// CHECK: for %i2 = 0 to #map{{[a-z_0-9]*}}(%i0) {
for %i3 = 0 to #map_non_simple1(%i0)[%arg1] {
for %i3 = 0 to #map_non_simple1(%i0)[%arg1] {
// CHECK: for %i3 = 0 to #map{{[a-z_0-9]*}}(%i0)[%arg1] {
for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] {
// CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] {
for %i5 = 0 to #map_non_simple3()[%arg0] {
// CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] {
for %i4 = 0 to #map_non_simple2()[%arg1, %arg0] {
// CHECK: for %i4 = 0 to #map{{[a-z_0-9]*}}()[%arg1, %arg0] {
for %i5 = 0 to #map_non_simple3()[%arg0] {
// CHECK: for %i5 = 0 to #map{{[a-z_0-9]*}}()[%arg0] {
%c42_i32 = constant 42 : i32
}
}
@ -704,7 +704,7 @@ bb0:
"fooi64"(){bar: sparse<vector<1xi64>, [[0]], [-1]>} : () -> ()
// CHECK: "foo2"() {bar: sparse<vector<0xi32>, {{\[}}], {{\[}}]>} : () -> ()
"foo2"(){bar: sparse<vector<0 x i32>, [], []>} : () -> ()
// CHECK: "foof16"() {bar: sparse<vector<1x1x1xf16>, {{\[\[}}0, 0, 0]], {{\[}}-2.000000e+00]>} : () -> ()
"foof16"(){bar: sparse<vector<1x1x1xf16>, [[0, 0, 0]], [-2.0]>} : () -> ()
// CHECK: "foobf16"() {bar: sparse<vector<2x2x2xbf16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]>} : () -> ()