Use APFloat for FloatAttribute

We should be able to represent arbitrary precision Float-point values inside
the IR, so compiler optimizations, such as constant folding can be done
independently on the compiling platform.

This CL also added a new field, AttrValueGetter, to the Attr class definition
for TableGen. This field is used to customize which mlir::Attr getter method to
get the defined PrimitiveType.

PiperOrigin-RevId: 218034983
This commit is contained in:
Feng Liu 2018-10-20 18:31:49 -07:00 committed by jpienaar
parent 2f1103bd93
commit c5a3a5e4ca
8 changed files with 80 additions and 49 deletions

View File

@ -20,7 +20,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
class Function;
@ -121,15 +122,15 @@ private:
int64_t value;
};
class FloatAttr : public Attribute {
class FloatAttr final : public Attribute,
public llvm::TrailingObjects<FloatAttr, uint64_t> {
public:
static FloatAttr *get(double value, MLIRContext *context);
static FloatAttr *get(const APFloat &value, MLIRContext *context);
// TODO: This should really be implemented in terms of APFloat for
// correctness, otherwise constant folding will be done with host math. This
// is completely incorrect for BF16 and other datatypes, and subtly wrong
// for float32.
double getValue() const { return value; }
APFloat getValue() const;
double getDouble() const { return getValue().convertToDouble(); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
@ -137,10 +138,18 @@ public:
}
private:
FloatAttr(double value)
: Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
semantics(semantics), numObjects(numObjects) {}
FloatAttr(const FloatAttr &value) = delete;
~FloatAttr() = delete;
double value;
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
return numObjects;
}
const llvm::fltSemantics &semantics;
size_t numObjects;
};
class StringAttr : public Attribute {

View File

@ -96,6 +96,7 @@ public:
BoolAttr *getBoolAttr(bool value);
IntegerAttr *getIntegerAttr(int64_t value);
FloatAttr *getFloatAttr(double value);
FloatAttr *getFloatAttr(const APFloat &value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap map);

View File

@ -116,10 +116,10 @@ protected:
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static void build(Builder *builder, OperationState *result, double value,
FloatType *type);
static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType *type);
double getValue() const {
APFloat getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
}

View File

@ -373,9 +373,7 @@ void ModulePrinter::print(const Module *module) {
/// Print a floating point value in a way that the parser will be able to
/// round-trip losslessly.
static void printFloatValue(double value, raw_ostream &os) {
APFloat apValue(value);
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
// We would like to output the FP constant value in exponential notation,
// but we cannot do this if doing so will lose precision. Check here to
// make sure that we only output it in exponential format if we can parse
@ -394,25 +392,15 @@ static void printFloatValue(double value, raw_ostream &os) {
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
"[-+]?[0-9] regex does not match!");
// Reparse stringized version!
if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
os << strValue;
return;
}
}
// Otherwise, print it in a hexadecimal form. Convert it to an integer so we
// can print it out using integer math.
union {
double doubleValue;
uint64_t integerValue;
};
doubleValue = value;
os << "0x";
// Print out 16 nibbles worth of hex digit.
for (unsigned i = 0; i != 16; ++i) {
os << llvm::hexdigit(integerValue >> 60);
integerValue <<= 4;
}
SmallVector<char, 16> str;
apValue.toString(str);
os << str;
}
void ModulePrinter::printFunctionReference(const Function *func) {

View File

@ -121,6 +121,10 @@ IntegerAttr *Builder::getIntegerAttr(int64_t value) {
}
FloatAttr *Builder::getFloatAttr(double value) {
return FloatAttr::get(APFloat(value), context);
}
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
return FloatAttr::get(value, context);
}

View File

@ -238,7 +238,7 @@ Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
double value, FloatType *type) {
const APFloat &value, FloatType *type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
}

View File

@ -32,7 +32,6 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
@ -147,6 +146,21 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
}
};
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
// Float attributes are uniqued based on wrapped APFloat.
using KeyTy = APFloat;
using DenseMapInfo<FloatAttr *>::getHashValue;
using DenseMapInfo<FloatAttr *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.bitwiseIsEqual(rhs->getValue());
}
};
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute *>;
@ -282,7 +296,7 @@ public:
// Attribute uniquing.
BoolAttr *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttr *> integerAttrs;
DenseMap<int64_t, FloatAttr *> floatAttrs;
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttr *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
@ -638,21 +652,36 @@ IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
}
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
// We hash based on the bit representation of the double to ensure we don't
// merge things like -0.0 and 0.0 in the hash comparison.
union {
double floatValue;
int64_t intValue;
};
floatValue = value;
return get(APFloat(value), context);
}
auto *&result = context->getImpl().floatAttrs[intValue];
if (result)
return result;
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
auto &impl = context->getImpl();
result = context->getImpl().allocator.Allocate<FloatAttr>();
new (result) FloatAttr(value);
return result;
// Look to see if the float attribute has been created already.
auto existing = impl.floatAttrs.insert_as(nullptr, value);
// If it has been created, return it.
if (!existing.second)
return *existing.first;
// If it doesn't, create one, unique it and return it.
const auto &apint = value.bitcastToAPInt();
// Here one word's bitwidth equals to that of uint64_t.
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return *existing.first = result;
}
APFloat FloatAttr::getValue() const {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {

View File

@ -699,7 +699,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Type::Kind::F64: {
if (!isa<FloatAttr>(result))
return p.emitError("expected tensor literal element has float type");
double value = cast<FloatAttr>(result)->getValue();
double value = cast<FloatAttr>(result)->getDouble();
addToStorage(*(uint64_t *)(&value));
break;
}
@ -823,7 +823,7 @@ Attribute *Parser::parseAttribute() {
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
return builder.getFloatAttr(val.getValue());
return builder.getFloatAttr(APFloat(val.getValue()));
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
@ -848,7 +848,7 @@ Attribute *Parser::parseAttribute() {
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
return builder.getFloatAttr(-val.getValue());
return builder.getFloatAttr(APFloat(-val.getValue()));
}
return (emitError("expected constant integer or floating point value"),