forked from OSchip/llvm-project
Use APFloat for FloatAttribute
We should be able to represent arbitrary precision Float-point values inside the IR, so compiler optimizations, such as constant folding can be done independently on the compiling platform. This CL also added a new field, AttrValueGetter, to the Attr class definition for TableGen. This field is used to customize which mlir::Attr getter method to get the defined PrimitiveType. PiperOrigin-RevId: 218034983
This commit is contained in:
parent
2f1103bd93
commit
c5a3a5e4ca
|
@ -20,7 +20,8 @@
|
||||||
|
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/APFloat.h"
|
||||||
|
#include "llvm/Support/TrailingObjects.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Function;
|
class Function;
|
||||||
|
@ -121,15 +122,15 @@ private:
|
||||||
int64_t value;
|
int64_t value;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FloatAttr : public Attribute {
|
class FloatAttr final : public Attribute,
|
||||||
|
public llvm::TrailingObjects<FloatAttr, uint64_t> {
|
||||||
public:
|
public:
|
||||||
static FloatAttr *get(double value, MLIRContext *context);
|
static FloatAttr *get(double value, MLIRContext *context);
|
||||||
|
static FloatAttr *get(const APFloat &value, MLIRContext *context);
|
||||||
|
|
||||||
// TODO: This should really be implemented in terms of APFloat for
|
APFloat getValue() const;
|
||||||
// correctness, otherwise constant folding will be done with host math. This
|
|
||||||
// is completely incorrect for BF16 and other datatypes, and subtly wrong
|
double getDouble() const { return getValue().convertToDouble(); }
|
||||||
// for float32.
|
|
||||||
double getValue() const { return value; }
|
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool classof(const Attribute *attr) {
|
||||||
|
@ -137,10 +138,18 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FloatAttr(double value)
|
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
|
||||||
: Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
|
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
|
||||||
|
semantics(semantics), numObjects(numObjects) {}
|
||||||
|
FloatAttr(const FloatAttr &value) = delete;
|
||||||
~FloatAttr() = delete;
|
~FloatAttr() = delete;
|
||||||
double value;
|
|
||||||
|
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
|
||||||
|
return numObjects;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llvm::fltSemantics &semantics;
|
||||||
|
size_t numObjects;
|
||||||
};
|
};
|
||||||
|
|
||||||
class StringAttr : public Attribute {
|
class StringAttr : public Attribute {
|
||||||
|
|
|
@ -96,6 +96,7 @@ public:
|
||||||
BoolAttr *getBoolAttr(bool value);
|
BoolAttr *getBoolAttr(bool value);
|
||||||
IntegerAttr *getIntegerAttr(int64_t value);
|
IntegerAttr *getIntegerAttr(int64_t value);
|
||||||
FloatAttr *getFloatAttr(double value);
|
FloatAttr *getFloatAttr(double value);
|
||||||
|
FloatAttr *getFloatAttr(const APFloat &value);
|
||||||
StringAttr *getStringAttr(StringRef bytes);
|
StringAttr *getStringAttr(StringRef bytes);
|
||||||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
||||||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
||||||
|
|
|
@ -116,10 +116,10 @@ protected:
|
||||||
class ConstantFloatOp : public ConstantOp {
|
class ConstantFloatOp : public ConstantOp {
|
||||||
public:
|
public:
|
||||||
/// Builds a constant float op producing a float of the specified type.
|
/// Builds a constant float op producing a float of the specified type.
|
||||||
static void build(Builder *builder, OperationState *result, double value,
|
static void build(Builder *builder, OperationState *result,
|
||||||
FloatType *type);
|
const APFloat &value, FloatType *type);
|
||||||
|
|
||||||
double getValue() const {
|
APFloat getValue() const {
|
||||||
return getAttrOfType<FloatAttr>("value")->getValue();
|
return getAttrOfType<FloatAttr>("value")->getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -373,9 +373,7 @@ void ModulePrinter::print(const Module *module) {
|
||||||
|
|
||||||
/// Print a floating point value in a way that the parser will be able to
|
/// Print a floating point value in a way that the parser will be able to
|
||||||
/// round-trip losslessly.
|
/// round-trip losslessly.
|
||||||
static void printFloatValue(double value, raw_ostream &os) {
|
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
|
||||||
APFloat apValue(value);
|
|
||||||
|
|
||||||
// We would like to output the FP constant value in exponential notation,
|
// We would like to output the FP constant value in exponential notation,
|
||||||
// but we cannot do this if doing so will lose precision. Check here to
|
// but we cannot do this if doing so will lose precision. Check here to
|
||||||
// make sure that we only output it in exponential format if we can parse
|
// make sure that we only output it in exponential format if we can parse
|
||||||
|
@ -394,25 +392,15 @@ static void printFloatValue(double value, raw_ostream &os) {
|
||||||
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
||||||
"[-+]?[0-9] regex does not match!");
|
"[-+]?[0-9] regex does not match!");
|
||||||
// Reparse stringized version!
|
// Reparse stringized version!
|
||||||
if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) {
|
if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
|
||||||
os << strValue;
|
os << strValue;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, print it in a hexadecimal form. Convert it to an integer so we
|
SmallVector<char, 16> str;
|
||||||
// can print it out using integer math.
|
apValue.toString(str);
|
||||||
union {
|
os << str;
|
||||||
double doubleValue;
|
|
||||||
uint64_t integerValue;
|
|
||||||
};
|
|
||||||
doubleValue = value;
|
|
||||||
os << "0x";
|
|
||||||
// Print out 16 nibbles worth of hex digit.
|
|
||||||
for (unsigned i = 0; i != 16; ++i) {
|
|
||||||
os << llvm::hexdigit(integerValue >> 60);
|
|
||||||
integerValue <<= 4;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printFunctionReference(const Function *func) {
|
void ModulePrinter::printFunctionReference(const Function *func) {
|
||||||
|
|
|
@ -121,6 +121,10 @@ IntegerAttr *Builder::getIntegerAttr(int64_t value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *Builder::getFloatAttr(double value) {
|
FloatAttr *Builder::getFloatAttr(double value) {
|
||||||
|
return FloatAttr::get(APFloat(value), context);
|
||||||
|
}
|
||||||
|
|
||||||
|
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
|
||||||
return FloatAttr::get(value, context);
|
return FloatAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -238,7 +238,7 @@ Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
||||||
double value, FloatType *type) {
|
const APFloat &value, FloatType *type) {
|
||||||
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,6 @@
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Support/MathExtras.h"
|
#include "mlir/Support/MathExtras.h"
|
||||||
#include "mlir/Support/STLExtras.h"
|
#include "mlir/Support/STLExtras.h"
|
||||||
#include "llvm/ADT/APInt.h"
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/StringMap.h"
|
#include "llvm/ADT/StringMap.h"
|
||||||
#include "llvm/Support/Allocator.h"
|
#include "llvm/Support/Allocator.h"
|
||||||
|
@ -147,6 +146,21 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
|
||||||
|
// Float attributes are uniqued based on wrapped APFloat.
|
||||||
|
using KeyTy = APFloat;
|
||||||
|
using DenseMapInfo<FloatAttr *>::getHashValue;
|
||||||
|
using DenseMapInfo<FloatAttr *>::isEqual;
|
||||||
|
|
||||||
|
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
||||||
|
|
||||||
|
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
|
||||||
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
|
return false;
|
||||||
|
return lhs.bitwiseIsEqual(rhs->getValue());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
|
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
|
||||||
// Array attributes are uniqued based on their elements.
|
// Array attributes are uniqued based on their elements.
|
||||||
using KeyTy = ArrayRef<Attribute *>;
|
using KeyTy = ArrayRef<Attribute *>;
|
||||||
|
@ -282,7 +296,7 @@ public:
|
||||||
// Attribute uniquing.
|
// Attribute uniquing.
|
||||||
BoolAttr *boolAttrs[2] = {nullptr};
|
BoolAttr *boolAttrs[2] = {nullptr};
|
||||||
DenseMap<int64_t, IntegerAttr *> integerAttrs;
|
DenseMap<int64_t, IntegerAttr *> integerAttrs;
|
||||||
DenseMap<int64_t, FloatAttr *> floatAttrs;
|
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
|
||||||
StringMap<StringAttr *> stringAttrs;
|
StringMap<StringAttr *> stringAttrs;
|
||||||
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
||||||
ArrayAttrSet arrayAttrs;
|
ArrayAttrSet arrayAttrs;
|
||||||
|
@ -638,21 +652,36 @@ IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
|
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
|
||||||
// We hash based on the bit representation of the double to ensure we don't
|
return get(APFloat(value), context);
|
||||||
// merge things like -0.0 and 0.0 in the hash comparison.
|
}
|
||||||
union {
|
|
||||||
double floatValue;
|
|
||||||
int64_t intValue;
|
|
||||||
};
|
|
||||||
floatValue = value;
|
|
||||||
|
|
||||||
auto *&result = context->getImpl().floatAttrs[intValue];
|
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||||
if (result)
|
auto &impl = context->getImpl();
|
||||||
return result;
|
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<FloatAttr>();
|
// Look to see if the float attribute has been created already.
|
||||||
new (result) FloatAttr(value);
|
auto existing = impl.floatAttrs.insert_as(nullptr, value);
|
||||||
return result;
|
|
||||||
|
// If it has been created, return it.
|
||||||
|
if (!existing.second)
|
||||||
|
return *existing.first;
|
||||||
|
|
||||||
|
// If it doesn't, create one, unique it and return it.
|
||||||
|
const auto &apint = value.bitcastToAPInt();
|
||||||
|
// Here one word's bitwidth equals to that of uint64_t.
|
||||||
|
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
||||||
|
|
||||||
|
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
|
||||||
|
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
|
||||||
|
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
|
||||||
|
std::uninitialized_copy(elements.begin(), elements.end(),
|
||||||
|
result->getTrailingObjects<uint64_t>());
|
||||||
|
return *existing.first = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
APFloat FloatAttr::getValue() const {
|
||||||
|
auto val = APInt(APFloat::getSizeInBits(semantics),
|
||||||
|
{getTrailingObjects<uint64_t>(), numObjects});
|
||||||
|
return APFloat(semantics, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
|
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||||
|
|
|
@ -699,7 +699,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
||||||
case Type::Kind::F64: {
|
case Type::Kind::F64: {
|
||||||
if (!isa<FloatAttr>(result))
|
if (!isa<FloatAttr>(result))
|
||||||
return p.emitError("expected tensor literal element has float type");
|
return p.emitError("expected tensor literal element has float type");
|
||||||
double value = cast<FloatAttr>(result)->getValue();
|
double value = cast<FloatAttr>(result)->getDouble();
|
||||||
addToStorage(*(uint64_t *)(&value));
|
addToStorage(*(uint64_t *)(&value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -823,7 +823,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
return (emitError("floating point value too large for attribute"),
|
return (emitError("floating point value too large for attribute"),
|
||||||
nullptr);
|
nullptr);
|
||||||
consumeToken(Token::floatliteral);
|
consumeToken(Token::floatliteral);
|
||||||
return builder.getFloatAttr(val.getValue());
|
return builder.getFloatAttr(APFloat(val.getValue()));
|
||||||
}
|
}
|
||||||
case Token::integer: {
|
case Token::integer: {
|
||||||
auto val = getToken().getUInt64IntegerValue();
|
auto val = getToken().getUInt64IntegerValue();
|
||||||
|
@ -848,7 +848,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
return (emitError("floating point value too large for attribute"),
|
return (emitError("floating point value too large for attribute"),
|
||||||
nullptr);
|
nullptr);
|
||||||
consumeToken(Token::floatliteral);
|
consumeToken(Token::floatliteral);
|
||||||
return builder.getFloatAttr(-val.getValue());
|
return builder.getFloatAttr(APFloat(-val.getValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return (emitError("expected constant integer or floating point value"),
|
return (emitError("expected constant integer or floating point value"),
|
||||||
|
|
Loading…
Reference in New Issue