Add a FloatAttr::getChecked, and invoke it during Attribute parsing.

PiperOrigin-RevId: 229167099
This commit is contained in:
River Riddle 2019-01-14 05:37:14 -08:00 committed by jpienaar
parent 1b171e9357
commit 791049fb34
3 changed files with 26 additions and 9 deletions

View File

@ -28,6 +28,7 @@ class Function;
class FunctionAttr;
class FunctionType;
class IntegerSet;
class Location;
class MLIRContext;
class Type;
class VectorOrTensorType;
@ -176,6 +177,8 @@ public:
static FloatAttr get(Type type, double value);
static FloatAttr get(Type type, const APFloat &value);
static FloatAttr getChecked(Type type, double value, Location loc);
APFloat getValue() const;
/// This function is used to convert the value to a double, even if it loses

View File

@ -818,7 +818,8 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
return get(type, APInt(intType.getWidth(), value));
}
FloatAttr FloatAttr::get(Type type, double value) {
static FloatAttr getFloatAttr(Type type, double value,
llvm::Optional<Location> loc) {
Optional<APFloat> val;
if (type.isBF16())
// Treat BF16 as double because it is not supported in LLVM's APFloat.
@ -836,14 +837,23 @@ FloatAttr FloatAttr::get(Type type, double value) {
auto status = (*val).convert(fltType.getFloatSemantics(),
APFloat::rmTowardZero, &unused);
if (status != APFloat::opOK) {
auto context = type.getContext();
context->emitError(
UnknownLoc::get(context),
"failed to convert floating point value to requested type");
val.reset();
if (loc)
type.getContext()->emitError(
*loc, "failed to convert floating point value to requested type");
return nullptr;
}
}
return get(type, *val);
return FloatAttr::get(type, *val);
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
return getFloatAttr(type, value, loc);
}
FloatAttr FloatAttr::get(Type type, double value) {
auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
assert(res && "failed to construct float attribute");
return res;
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {

View File

@ -875,6 +875,7 @@ Attribute Parser::parseAttribute(Type type) {
if (!val.hasValue())
return (emitError("floating point value too large for attribute"),
nullptr);
auto valTok = getToken().getLoc();
consumeToken(Token::floatliteral);
if (!type) {
if (consumeIf(Token::colon)) {
@ -888,7 +889,8 @@ Attribute Parser::parseAttribute(Type type) {
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for specified type"),
nullptr);
return builder.getFloatAttr(type, val.getValue());
return FloatAttr::getChecked(type, val.getValue(),
getEncodedSourceLocation(valTok));
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
@ -945,6 +947,7 @@ Attribute Parser::parseAttribute(Type type) {
if (!val.hasValue())
return (emitError("floating point value too large for attribute"),
nullptr);
auto valTok = getToken().getLoc();
consumeToken(Token::floatliteral);
if (!type) {
if (consumeIf(Token::colon)) {
@ -957,7 +960,8 @@ Attribute Parser::parseAttribute(Type type) {
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for type"), nullptr);
return builder.getFloatAttr(type, -val.getValue());
return FloatAttr::getChecked(type, -val.getValue(),
getEncodedSourceLocation(valTok));
}
return (emitError("expected constant integer or floating point value"),