forked from OSchip/llvm-project
Use APFloat for FloatAttribute
We should be able to represent arbitrary precision Float-point values inside the IR, so compiler optimizations, such as constant folding can be done independently on the compiling platform. This CL also added a new field, AttrValueGetter, to the Attr class definition for TableGen. This field is used to customize which mlir::Attr getter method to get the defined PrimitiveType. PiperOrigin-RevId: 218034983
This commit is contained in:
parent
2f1103bd93
commit
c5a3a5e4ca
|
@ -20,7 +20,8 @@
|
|||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
|
@ -121,15 +122,15 @@ private:
|
|||
int64_t value;
|
||||
};
|
||||
|
||||
class FloatAttr : public Attribute {
|
||||
class FloatAttr final : public Attribute,
|
||||
public llvm::TrailingObjects<FloatAttr, uint64_t> {
|
||||
public:
|
||||
static FloatAttr *get(double value, MLIRContext *context);
|
||||
static FloatAttr *get(const APFloat &value, MLIRContext *context);
|
||||
|
||||
// TODO: This should really be implemented in terms of APFloat for
|
||||
// correctness, otherwise constant folding will be done with host math. This
|
||||
// is completely incorrect for BF16 and other datatypes, and subtly wrong
|
||||
// for float32.
|
||||
double getValue() const { return value; }
|
||||
APFloat getValue() const;
|
||||
|
||||
double getDouble() const { return getValue().convertToDouble(); }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
|
@ -137,10 +138,18 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
FloatAttr(double value)
|
||||
: Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
|
||||
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
|
||||
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
|
||||
semantics(semantics), numObjects(numObjects) {}
|
||||
FloatAttr(const FloatAttr &value) = delete;
|
||||
~FloatAttr() = delete;
|
||||
double value;
|
||||
|
||||
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
|
||||
return numObjects;
|
||||
}
|
||||
|
||||
const llvm::fltSemantics &semantics;
|
||||
size_t numObjects;
|
||||
};
|
||||
|
||||
class StringAttr : public Attribute {
|
||||
|
|
|
@ -96,6 +96,7 @@ public:
|
|||
BoolAttr *getBoolAttr(bool value);
|
||||
IntegerAttr *getIntegerAttr(int64_t value);
|
||||
FloatAttr *getFloatAttr(double value);
|
||||
FloatAttr *getFloatAttr(const APFloat &value);
|
||||
StringAttr *getStringAttr(StringRef bytes);
|
||||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
||||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
||||
|
|
|
@ -116,10 +116,10 @@ protected:
|
|||
class ConstantFloatOp : public ConstantOp {
|
||||
public:
|
||||
/// Builds a constant float op producing a float of the specified type.
|
||||
static void build(Builder *builder, OperationState *result, double value,
|
||||
FloatType *type);
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
const APFloat &value, FloatType *type);
|
||||
|
||||
double getValue() const {
|
||||
APFloat getValue() const {
|
||||
return getAttrOfType<FloatAttr>("value")->getValue();
|
||||
}
|
||||
|
||||
|
|
|
@ -373,9 +373,7 @@ void ModulePrinter::print(const Module *module) {
|
|||
|
||||
/// Print a floating point value in a way that the parser will be able to
|
||||
/// round-trip losslessly.
|
||||
static void printFloatValue(double value, raw_ostream &os) {
|
||||
APFloat apValue(value);
|
||||
|
||||
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
|
||||
// We would like to output the FP constant value in exponential notation,
|
||||
// but we cannot do this if doing so will lose precision. Check here to
|
||||
// make sure that we only output it in exponential format if we can parse
|
||||
|
@ -394,25 +392,15 @@ static void printFloatValue(double value, raw_ostream &os) {
|
|||
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
||||
"[-+]?[0-9] regex does not match!");
|
||||
// Reparse stringized version!
|
||||
if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
|
||||
if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
|
||||
os << strValue;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, print it in a hexadecimal form. Convert it to an integer so we
|
||||
// can print it out using integer math.
|
||||
union {
|
||||
double doubleValue;
|
||||
uint64_t integerValue;
|
||||
};
|
||||
doubleValue = value;
|
||||
os << "0x";
|
||||
// Print out 16 nibbles worth of hex digit.
|
||||
for (unsigned i = 0; i != 16; ++i) {
|
||||
os << llvm::hexdigit(integerValue >> 60);
|
||||
integerValue <<= 4;
|
||||
}
|
||||
SmallVector<char, 16> str;
|
||||
apValue.toString(str);
|
||||
os << str;
|
||||
}
|
||||
|
||||
void ModulePrinter::printFunctionReference(const Function *func) {
|
||||
|
|
|
@ -121,6 +121,10 @@ IntegerAttr *Builder::getIntegerAttr(int64_t value) {
|
|||
}
|
||||
|
||||
FloatAttr *Builder::getFloatAttr(double value) {
|
||||
return FloatAttr::get(APFloat(value), context);
|
||||
}
|
||||
|
||||
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
|
||||
return FloatAttr::get(value, context);
|
||||
}
|
||||
|
||||
|
|
|
@ -238,7 +238,7 @@ Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
}
|
||||
|
||||
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
||||
double value, FloatType *type) {
|
||||
const APFloat &value, FloatType *type) {
|
||||
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
|
@ -147,6 +146,21 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
|||
}
|
||||
};
|
||||
|
||||
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
|
||||
// Float attributes are uniqued based on wrapped APFloat.
|
||||
using KeyTy = APFloat;
|
||||
using DenseMapInfo<FloatAttr *>::getHashValue;
|
||||
using DenseMapInfo<FloatAttr *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs.bitwiseIsEqual(rhs->getValue());
|
||||
}
|
||||
};
|
||||
|
||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
|
||||
// Array attributes are uniqued based on their elements.
|
||||
using KeyTy = ArrayRef<Attribute *>;
|
||||
|
@ -282,7 +296,7 @@ public:
|
|||
// Attribute uniquing.
|
||||
BoolAttr *boolAttrs[2] = {nullptr};
|
||||
DenseMap<int64_t, IntegerAttr *> integerAttrs;
|
||||
DenseMap<int64_t, FloatAttr *> floatAttrs;
|
||||
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
|
||||
StringMap<StringAttr *> stringAttrs;
|
||||
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
||||
ArrayAttrSet arrayAttrs;
|
||||
|
@ -638,21 +652,36 @@ IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
|
|||
}
|
||||
|
||||
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
|
||||
// We hash based on the bit representation of the double to ensure we don't
|
||||
// merge things like -0.0 and 0.0 in the hash comparison.
|
||||
union {
|
||||
double floatValue;
|
||||
int64_t intValue;
|
||||
};
|
||||
floatValue = value;
|
||||
return get(APFloat(value), context);
|
||||
}
|
||||
|
||||
auto *&result = context->getImpl().floatAttrs[intValue];
|
||||
if (result)
|
||||
return result;
|
||||
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
result = context->getImpl().allocator.Allocate<FloatAttr>();
|
||||
new (result) FloatAttr(value);
|
||||
return result;
|
||||
// Look to see if the float attribute has been created already.
|
||||
auto existing = impl.floatAttrs.insert_as(nullptr, value);
|
||||
|
||||
// If it has been created, return it.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// If it doesn't, create one, unique it and return it.
|
||||
const auto &apint = value.bitcastToAPInt();
|
||||
// Here one word's bitwidth equals to that of uint64_t.
|
||||
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
||||
|
||||
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
|
||||
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
|
||||
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
|
||||
std::uninitialized_copy(elements.begin(), elements.end(),
|
||||
result->getTrailingObjects<uint64_t>());
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
APFloat FloatAttr::getValue() const {
|
||||
auto val = APInt(APFloat::getSizeInBits(semantics),
|
||||
{getTrailingObjects<uint64_t>(), numObjects});
|
||||
return APFloat(semantics, val);
|
||||
}
|
||||
|
||||
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||
|
|
|
@ -699,7 +699,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
case Type::Kind::F64: {
|
||||
if (!isa<FloatAttr>(result))
|
||||
return p.emitError("expected tensor literal element has float type");
|
||||
double value = cast<FloatAttr>(result)->getValue();
|
||||
double value = cast<FloatAttr>(result)->getDouble();
|
||||
addToStorage(*(uint64_t *)(&value));
|
||||
break;
|
||||
}
|
||||
|
@ -823,7 +823,7 @@ Attribute *Parser::parseAttribute() {
|
|||
return (emitError("floating point value too large for attribute"),
|
||||
nullptr);
|
||||
consumeToken(Token::floatliteral);
|
||||
return builder.getFloatAttr(val.getValue());
|
||||
return builder.getFloatAttr(APFloat(val.getValue()));
|
||||
}
|
||||
case Token::integer: {
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
|
@ -848,7 +848,7 @@ Attribute *Parser::parseAttribute() {
|
|||
return (emitError("floating point value too large for attribute"),
|
||||
nullptr);
|
||||
consumeToken(Token::floatliteral);
|
||||
return builder.getFloatAttr(-val.getValue());
|
||||
return builder.getFloatAttr(APFloat(-val.getValue()));
|
||||
}
|
||||
|
||||
return (emitError("expected constant integer or floating point value"),
|
||||
|
|
Loading…
Reference in New Issue