forked from OSchip/llvm-project
Densify storage for f16, f32 and support f16 semantics in FloatAttrs
Existing implementation always uses 64 bits to store floating point values in DenseElementsAttr. This was due to FloatAttrs always a `double` for storage independently of the actual type. Recent commits added support for FloatAttrs with the proper f32 type and floating semantics and changed the bitwidth reporting on FloatType. Use the existing infrastructure for densely storing 16 and 32-bit values in DenseElementsAttr storage to store f16 and f32 values. Move floating semantics definition to the FloatType level. Properly support f16 / IEEEhalf semantics at the FloatAttr level and in the builder. Note that bf16 is still stored as a 64-bit value with IEEEdouble semantics because APFloat does not have first-class support for bf16 types. PiperOrigin-RevId: 225981289
This commit is contained in:
parent
20531932f4
commit
49c81ebcb0
|
@ -369,6 +369,16 @@ public:
|
|||
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
|
||||
/// in array `rawData`.
|
||||
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value);
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
||||
/// `rawData` and return them as the lowest bits of an uint64 integer.
|
||||
static uint64_t readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth);
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
|
||||
|
@ -389,16 +399,6 @@ public:
|
|||
|
||||
APInt getValue(ArrayRef<unsigned> indices) const;
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
|
||||
/// in array `rawData`.
|
||||
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value);
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
||||
/// `rawData` and return them as the lowest bits of an uint64 integer.
|
||||
static uint64_t readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth);
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
|
||||
};
|
||||
|
|
|
@ -22,6 +22,10 @@
|
|||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
namespace llvm {
|
||||
class fltSemantics;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class FloatType;
|
||||
|
@ -254,6 +258,9 @@ public:
|
|||
|
||||
/// Return the bitwidth of this float type.
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Return the floating semantics of this float type.
|
||||
const llvm::fltSemantics &getFloatSemantics() const;
|
||||
};
|
||||
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
|
|
|
@ -159,13 +159,10 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|||
return static_cast<ImplType *>(attr)->data;
|
||||
}
|
||||
|
||||
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
|
||||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
||||
/// starting from `rawData`.
|
||||
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value) {
|
||||
void DenseElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value) {
|
||||
// Read the destination bytes which will be written to.
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
|
@ -189,8 +186,8 @@ void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
|||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
||||
/// and put them in the lowest bits.
|
||||
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth) {
|
||||
uint64_t DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth) {
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitsWidth;
|
||||
|
@ -203,6 +200,9 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|||
return dst;
|
||||
}
|
||||
|
||||
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
|
||||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||
auto elementNum = getType().getNumElements();
|
||||
|
@ -230,14 +230,38 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
|||
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
|
||||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
// Construct a FloatAttr wrapping a float value of `elementType` type from its
|
||||
// bit representation. The APFloat stored in the attribute will have the
|
||||
// semantics defined by the float semantics of the element type.
|
||||
static inline FloatAttr makeFloatAttrFromBits(size_t bitWidth, uint64_t bits,
|
||||
FloatType elementType) {
|
||||
auto apint = APInt(bitWidth, bits);
|
||||
auto apfloat = APFloat(elementType.getFloatSemantics(), apint);
|
||||
return FloatAttr::get(elementType, apfloat);
|
||||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementNum = getType().getNumElements();
|
||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
auto elementType = getType().getElementType().dyn_cast<FloatType>();
|
||||
assert(elementType && "non-float type in FP attribute");
|
||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
||||
// semantics.
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
|
||||
values.reserve(elementNum);
|
||||
for (auto v : vs) {
|
||||
auto attr = FloatAttr::get(getType().getElementType(), v);
|
||||
values.push_back(attr);
|
||||
if (bitWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
for (auto bitValue : vs) {
|
||||
values.push_back(makeFloatAttrFromBits(64, bitValue, elementType));
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (unsigned i = 0; i < elementNum; ++i) {
|
||||
uint64_t bits = readBits(getRawData().data(), i * bitWidth, bitWidth);
|
||||
values.push_back(makeFloatAttrFromBits(bitWidth, bits, elementType));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1178,42 +1178,23 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
|||
return get(type, APInt(intType.getWidth(), value));
|
||||
}
|
||||
|
||||
/// Returns the floating semantics for the given type.
|
||||
static const fltSemantics &getFloatSemantics(Type type) {
|
||||
if (type.isBF16())
|
||||
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
|
||||
// not defined in LLVM.
|
||||
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
|
||||
// else one could add it.
|
||||
// static const fltSemantics semBF16 = {127, -126, 8, 16};
|
||||
return APFloat::IEEEdouble();
|
||||
if (type.isF16())
|
||||
// Treat F16 as double. This avoids needing to change the tensor element
|
||||
// parsing for now. TODO: Fix this to use the correct semantics instead.
|
||||
return APFloat::IEEEdouble();
|
||||
if (type.isF32())
|
||||
return APFloat::IEEEsingle();
|
||||
if (type.isF64())
|
||||
return APFloat::IEEEdouble();
|
||||
llvm_unreachable("non-floating point type used");
|
||||
}
|
||||
|
||||
FloatAttr FloatAttr::get(Type type, double value) {
|
||||
Optional<APFloat> val;
|
||||
if (type.isBF16() || type.isF16())
|
||||
// Treat BF16 and F16 as double. This avoids needing to change the tensor
|
||||
// element parsing for now. TODO: Fix this to use the correct semantics
|
||||
// instead.
|
||||
if (type.isBF16())
|
||||
// Treat BF16 as double because it is not supported in LLVM's APFloat.
|
||||
// TODO(jpienaar): add BF16 support to APFloat?
|
||||
val = APFloat(value);
|
||||
else if (type.isF32())
|
||||
val = APFloat(static_cast<float>(value));
|
||||
else if (type.isF64())
|
||||
val = APFloat(value);
|
||||
else {
|
||||
// This handles, e.g., F16 because there is no APFloat constructor for it.
|
||||
bool unused;
|
||||
val = APFloat(value);
|
||||
auto status =
|
||||
(*val).convert(getFloatSemantics(type), APFloat::rmTowardZero, &unused);
|
||||
auto fltType = type.cast<FloatType>();
|
||||
auto status = (*val).convert(fltType.getFloatSemantics(),
|
||||
APFloat::rmTowardZero, &unused);
|
||||
if (status != APFloat::opOK) {
|
||||
auto context = type.getContext();
|
||||
context->emitError(
|
||||
|
@ -1226,8 +1207,10 @@ FloatAttr FloatAttr::get(Type type, double value) {
|
|||
}
|
||||
|
||||
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
|
||||
assert(&getFloatSemantics(type) == &value.getSemantics() &&
|
||||
auto fltType = type.cast<FloatType>();
|
||||
assert(&fltType.getFloatSemantics() == &value.getSemantics() &&
|
||||
"FloatAttr type doesn't match the type implied by its value");
|
||||
(void)fltType;
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if the float attribute has been created already.
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -55,6 +56,24 @@ unsigned FloatType::getWidth() const {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the floating semantics for the given type.
|
||||
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
|
||||
if (isBF16())
|
||||
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
|
||||
// not defined in LLVM.
|
||||
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
|
||||
// else one could add it.
|
||||
// static const fltSemantics semBF16 = {127, -126, 8, 16};
|
||||
return APFloat::IEEEdouble();
|
||||
if (isF16())
|
||||
return APFloat::IEEEhalf();
|
||||
if (isF32())
|
||||
return APFloat::IEEEsingle();
|
||||
if (isF64())
|
||||
return APFloat::IEEEdouble();
|
||||
llvm_unreachable("non-floating point type used");
|
||||
}
|
||||
|
||||
unsigned Type::getIntOrFloatBitWidth() const {
|
||||
assert(isIntOrFloat() && "only ints and floats have a bitwidth");
|
||||
if (auto intType = dyn_cast<IntegerType>()) {
|
||||
|
|
|
@ -649,11 +649,9 @@ private:
|
|||
|
||||
void addToStorage(uint64_t value) {
|
||||
// Only tensors of integers or floats are supported.
|
||||
// TODO: we currently use 64 bit for all floating point constants for legacy
|
||||
// reasoins. For f16 and f32, this is fixable by bitcasting APFloat value
|
||||
// to APInt, but APFloat does not support bf16 semantics.
|
||||
auto eltIntTy = eltTy.dyn_cast<IntegerType>();
|
||||
size_t bitWidth = eltIntTy ? eltIntTy.getWidth() : 64;
|
||||
// FIXME: use full word to store BF16 as double because APFloat, which we
|
||||
// use to work with floats, does not have support for BF16 yet.
|
||||
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
|
||||
|
||||
if (bitWidth == 64)
|
||||
storage.push_back(value);
|
||||
|
@ -662,7 +660,7 @@ private:
|
|||
storage.push_back(0L);
|
||||
|
||||
auto *rawData = reinterpret_cast<char *>(storage.data());
|
||||
DenseIntElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
|
||||
DenseElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
|
||||
currBitPos += bitWidth;
|
||||
}
|
||||
|
||||
|
@ -689,22 +687,24 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
return ParseResult::ParseFailure;
|
||||
// check result matches the element type.
|
||||
switch (eltTy.getKind()) {
|
||||
case Type::Kind::F32: {
|
||||
if (!result.isa<FloatAttr>())
|
||||
return p.emitError(
|
||||
"expected tensor literal element has floating point type");
|
||||
double value = result.cast<FloatAttr>().getValue().convertToFloat();
|
||||
addToStorage(*(uint64_t *)(&value));
|
||||
break;
|
||||
}
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
case Type::Kind::F64: {
|
||||
if (!result.isa<FloatAttr>())
|
||||
// Bitcast the APFloat value to APInt and store the bit representation.
|
||||
auto fpAttrResult = result.dyn_cast<FloatAttr>();
|
||||
if (!fpAttrResult)
|
||||
return p.emitError(
|
||||
"expected tensor literal element has floating point type");
|
||||
double value = result.cast<FloatAttr>().getDouble();
|
||||
addToStorage(*(uint64_t *)(&value));
|
||||
"expected tensor literal element with floating point type");
|
||||
auto apInt = fpAttrResult.getValue().bitcastToAPInt();
|
||||
|
||||
// FIXME: using 64 bits and double semantics for BF16 because APFloat does
|
||||
// not support BF16 directly.
|
||||
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
|
||||
assert(apInt.getBitWidth() == bitWidth);
|
||||
(void)bitWidth;
|
||||
|
||||
addToStorage(apInt.getRawData()[0]);
|
||||
break;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
|
|
Loading…
Reference in New Issue