From c5a3a5e4cad7063122369257dbf0046397c59ef9 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Sat, 20 Oct 2018 18:31:49 -0700 Subject: [PATCH] Use APFloat for FloatAttribute We should be able to represent arbitrary precision Float-point values inside the IR, so compiler optimizations, such as constant folding can be done independently on the compiling platform. This CL also added a new field, AttrValueGetter, to the Attr class definition for TableGen. This field is used to customize which mlir::Attr getter method to get the defined PrimitiveType. PiperOrigin-RevId: 218034983 --- mlir/include/mlir/IR/Attributes.h | 29 +++++++++------ mlir/include/mlir/IR/Builders.h | 1 + mlir/include/mlir/IR/BuiltinOps.h | 6 ++-- mlir/lib/IR/AsmPrinter.cpp | 22 +++--------- mlir/lib/IR/Builders.cpp | 4 +++ mlir/lib/IR/BuiltinOps.cpp | 2 +- mlir/lib/IR/MLIRContext.cpp | 59 +++++++++++++++++++++++-------- mlir/lib/Parser/Parser.cpp | 6 ++-- 8 files changed, 80 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index a676bf276f53..28a3939fb1e1 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -20,7 +20,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/Support/TrailingObjects.h" namespace mlir { class Function; @@ -121,15 +122,15 @@ private: int64_t value; }; -class FloatAttr : public Attribute { +class FloatAttr final : public Attribute, + public llvm::TrailingObjects { public: static FloatAttr *get(double value, MLIRContext *context); + static FloatAttr *get(const APFloat &value, MLIRContext *context); - // TODO: This should really be implemented in terms of APFloat for - // correctness, otherwise constant folding will be done with host math. This - // is completely incorrect for BF16 and other datatypes, and subtly wrong - // for float32. - double getValue() const { return value; } + APFloat getValue() const; + + double getDouble() const { return getValue().convertToDouble(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Attribute *attr) { @@ -137,10 +138,18 @@ public: } private: - FloatAttr(double value) - : Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {} + FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects) + : Attribute(Kind::Float, /*isOrContainsFunction=*/false), + semantics(semantics), numObjects(numObjects) {} + FloatAttr(const FloatAttr &value) = delete; ~FloatAttr() = delete; - double value; + + size_t numTrailingObjects(OverloadToken) const { + return numObjects; + } + + const llvm::fltSemantics &semantics; + size_t numObjects; }; class StringAttr : public Attribute { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 4e44211c3444..0415637f3f07 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -96,6 +96,7 @@ public: BoolAttr *getBoolAttr(bool value); IntegerAttr *getIntegerAttr(int64_t value); FloatAttr *getFloatAttr(double value); + FloatAttr *getFloatAttr(const APFloat &value); StringAttr *getStringAttr(StringRef bytes); ArrayAttr *getArrayAttr(ArrayRef value); AffineMapAttr *getAffineMapAttr(AffineMap map); diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index fb2c2acdfd8e..9030f97572b9 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -116,10 +116,10 @@ protected: class ConstantFloatOp : public ConstantOp { public: /// Builds a constant float op producing a float of the specified type. - static void build(Builder *builder, OperationState *result, double value, - FloatType *type); + static void build(Builder *builder, OperationState *result, + const APFloat &value, FloatType *type); - double getValue() const { + APFloat getValue() const { return getAttrOfType("value")->getValue(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e70dcae29360..3c71f9a6b48c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -373,9 +373,7 @@ void ModulePrinter::print(const Module *module) { /// Print a floating point value in a way that the parser will be able to /// round-trip losslessly. -static void printFloatValue(double value, raw_ostream &os) { - APFloat apValue(value); - +static void printFloatValue(const APFloat &apValue, raw_ostream &os) { // We would like to output the FP constant value in exponential notation, // but we cannot do this if doing so will lose precision. Check here to // make sure that we only output it in exponential format if we can parse @@ -394,25 +392,15 @@ static void printFloatValue(double value, raw_ostream &os) { (strValue[1] >= '0' && strValue[1] <= '9'))) && "[-+]?[0-9] regex does not match!"); // Reparse stringized version! - if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) { + if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) { os << strValue; return; } } - // Otherwise, print it in a hexadecimal form. Convert it to an integer so we - // can print it out using integer math. - union { - double doubleValue; - uint64_t integerValue; - }; - doubleValue = value; - os << "0x"; - // Print out 16 nibbles worth of hex digit. - for (unsigned i = 0; i != 16; ++i) { - os << llvm::hexdigit(integerValue >> 60); - integerValue <<= 4; - } + SmallVector str; + apValue.toString(str); + os << str; } void ModulePrinter::printFunctionReference(const Function *func) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 3e22c852ee42..66192f0a867a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -121,6 +121,10 @@ IntegerAttr *Builder::getIntegerAttr(int64_t value) { } FloatAttr *Builder::getFloatAttr(double value) { + return FloatAttr::get(APFloat(value), context); +} + +FloatAttr *Builder::getFloatAttr(const APFloat &value) { return FloatAttr::get(value, context); } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 2acc26d73af3..fe943026ecbc 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -238,7 +238,7 @@ Attribute *ConstantOp::constantFold(ArrayRef operands, } void ConstantFloatOp::build(Builder *builder, OperationState *result, - double value, FloatType *type) { + const APFloat &value, FloatType *type) { ConstantOp::build(builder, result, builder->getFloatAttr(value), type); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 998342889262..1b9fe0cde63c 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -32,7 +32,6 @@ #include "mlir/IR/Types.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" -#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Allocator.h" @@ -147,6 +146,21 @@ struct MemRefTypeKeyInfo : DenseMapInfo { } }; +struct FloatAttrKeyInfo : DenseMapInfo { + // Float attributes are uniqued based on wrapped APFloat. + using KeyTy = APFloat; + using DenseMapInfo::getHashValue; + using DenseMapInfo::isEqual; + + static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); } + + static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs.bitwiseIsEqual(rhs->getValue()); + } +}; + struct ArrayAttrKeyInfo : DenseMapInfo { // Array attributes are uniqued based on their elements. using KeyTy = ArrayRef; @@ -282,7 +296,7 @@ public: // Attribute uniquing. BoolAttr *boolAttrs[2] = {nullptr}; DenseMap integerAttrs; - DenseMap floatAttrs; + DenseSet floatAttrs; StringMap stringAttrs; using ArrayAttrSet = DenseSet; ArrayAttrSet arrayAttrs; @@ -638,21 +652,36 @@ IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) { } FloatAttr *FloatAttr::get(double value, MLIRContext *context) { - // We hash based on the bit representation of the double to ensure we don't - // merge things like -0.0 and 0.0 in the hash comparison. - union { - double floatValue; - int64_t intValue; - }; - floatValue = value; + return get(APFloat(value), context); +} - auto *&result = context->getImpl().floatAttrs[intValue]; - if (result) - return result; +FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) { + auto &impl = context->getImpl(); - result = context->getImpl().allocator.Allocate(); - new (result) FloatAttr(value); - return result; + // Look to see if the float attribute has been created already. + auto existing = impl.floatAttrs.insert_as(nullptr, value); + + // If it has been created, return it. + if (!existing.second) + return *existing.first; + + // If it doesn't, create one, unique it and return it. + const auto &apint = value.bitcastToAPInt(); + // Here one word's bitwidth equals to that of uint64_t. + auto elements = ArrayRef(apint.getRawData(), apint.getNumWords()); + + auto byteSize = FloatAttr::totalSizeToAlloc(elements.size()); + auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr)); + auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + result->getTrailingObjects()); + return *existing.first = result; +} + +APFloat FloatAttr::getValue() const { + auto val = APInt(APFloat::getSizeInBits(semantics), + {getTrailingObjects(), numObjects}); + return APFloat(semantics, val); } StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index cbd3d8b7bd72..448f974b348e 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -699,7 +699,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { case Type::Kind::F64: { if (!isa(result)) return p.emitError("expected tensor literal element has float type"); - double value = cast(result)->getValue(); + double value = cast(result)->getDouble(); addToStorage(*(uint64_t *)(&value)); break; } @@ -823,7 +823,7 @@ Attribute *Parser::parseAttribute() { return (emitError("floating point value too large for attribute"), nullptr); consumeToken(Token::floatliteral); - return builder.getFloatAttr(val.getValue()); + return builder.getFloatAttr(APFloat(val.getValue())); } case Token::integer: { auto val = getToken().getUInt64IntegerValue(); @@ -848,7 +848,7 @@ Attribute *Parser::parseAttribute() { return (emitError("floating point value too large for attribute"), nullptr); consumeToken(Token::floatliteral); - return builder.getFloatAttr(-val.getValue()); + return builder.getFloatAttr(APFloat(-val.getValue())); } return (emitError("expected constant integer or floating point value"),