Densify storage for f16, f32 and support f16 semantics in FloatAttrs

Existing implementation always uses 64 bits to store floating point values in
DenseElementsAttr.  This was due to FloatAttrs always a `double` for storage
independently of the actual type.  Recent commits added support for FloatAttrs
with the proper f32 type and floating semantics and changed the bitwidth
reporting on FloatType.

Use the existing infrastructure for densely storing 16 and 32-bit values in
DenseElementsAttr storage to store f16 and f32 values.  Move floating semantics
definition to the FloatType level.  Properly support f16 / IEEEhalf semantics
at the FloatAttr level and in the builder.

Note that bf16 is still stored as a 64-bit value with IEEEdouble semantics
because APFloat does not have first-class support for bf16 types.

PiperOrigin-RevId: 225981289
This commit is contained in:
Alex Zinenko 2018-12-18 05:25:17 -08:00 committed by jpienaar
parent 20531932f4
commit 49c81ebcb0
6 changed files with 100 additions and 67 deletions

View File

@ -369,6 +369,16 @@ public:
ArrayRef<char> getRawData() const; ArrayRef<char> getRawData() const;
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
/// in array `rawData`.
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
uint64_t value);
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData` and return them as the lowest bits of an uint64 integer.
static uint64_t readBits(const char *rawData, size_t bitPos,
size_t bitsWidth);
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { static bool kindof(Kind kind) {
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements; return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
@ -389,16 +399,6 @@ public:
APInt getValue(ArrayRef<unsigned> indices) const; APInt getValue(ArrayRef<unsigned> indices) const;
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
/// in array `rawData`.
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
uint64_t value);
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData` and return them as the lowest bits of an uint64 integer.
static uint64_t readBits(const char *rawData, size_t bitPos,
size_t bitsWidth);
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; } static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
}; };

View File

@ -22,6 +22,10 @@
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseMapInfo.h"
namespace llvm {
class fltSemantics;
} // namespace llvm
namespace mlir { namespace mlir {
class AffineMap; class AffineMap;
class FloatType; class FloatType;
@ -254,6 +258,9 @@ public:
/// Return the bitwidth of this float type. /// Return the bitwidth of this float type.
unsigned getWidth() const; unsigned getWidth() const;
/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics() const;
}; };
inline FloatType Type::getBF16(MLIRContext *ctx) { inline FloatType Type::getBF16(MLIRContext *ctx) {

View File

@ -159,13 +159,10 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(attr)->data; return static_cast<ImplType *>(attr)->data;
} }
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos` /// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`. /// starting from `rawData`.
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth, void DenseElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) { uint64_t value) {
// Read the destination bytes which will be written to. // Read the destination bytes which will be written to.
uint64_t dst = 0; uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst); auto dstData = reinterpret_cast<char *>(&dst);
@ -189,8 +186,8 @@ void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData` /// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits. /// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos, uint64_t DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) { size_t bitsWidth) {
uint64_t dst = 0; uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst); auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth; auto endPos = bitPos + bitsWidth;
@ -203,6 +200,9 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
return dst; return dst;
} }
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth; auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType().getNumElements(); auto elementNum = getType().getNumElements();
@ -230,14 +230,38 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr) DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {} : DenseElementsAttr(ptr) {}
// Construct a FloatAttr wrapping a float value of `elementType` type from its
// bit representation. The APFloat stored in the attribute will have the
// semantics defined by the float semantics of the element type.
static inline FloatAttr makeFloatAttrFromBits(size_t bitWidth, uint64_t bits,
FloatType elementType) {
auto apint = APInt(bitWidth, bits);
auto apfloat = APFloat(elementType.getFloatSemantics(), apint);
return FloatAttr::get(elementType, apfloat);
}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType().getNumElements(); auto elementNum = getType().getNumElements();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()), auto elementType = getType().getElementType().dyn_cast<FloatType>();
getRawData().size() / 8}); assert(elementType && "non-float type in FP attribute");
// FIXME: using 64 bits for BF16 because it is currently stored with double
// semantics.
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
values.reserve(elementNum); values.reserve(elementNum);
for (auto v : vs) { if (bitWidth == 64) {
auto attr = FloatAttr::get(getType().getElementType(), v); ArrayRef<int64_t> vs(
values.push_back(attr); {reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto bitValue : vs) {
values.push_back(makeFloatAttrFromBits(64, bitValue, elementType));
}
return;
}
for (unsigned i = 0; i < elementNum; ++i) {
uint64_t bits = readBits(getRawData().data(), i * bitWidth, bitWidth);
values.push_back(makeFloatAttrFromBits(bitWidth, bits, elementType));
} }
} }

View File

@ -1178,42 +1178,23 @@ IntegerAttr IntegerAttr::get(Type type, int64_t value) {
return get(type, APInt(intType.getWidth(), value)); return get(type, APInt(intType.getWidth(), value));
} }
/// Returns the floating semantics for the given type.
static const fltSemantics &getFloatSemantics(Type type) {
if (type.isBF16())
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
// not defined in LLVM.
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
// else one could add it.
// static const fltSemantics semBF16 = {127, -126, 8, 16};
return APFloat::IEEEdouble();
if (type.isF16())
// Treat F16 as double. This avoids needing to change the tensor element
// parsing for now. TODO: Fix this to use the correct semantics instead.
return APFloat::IEEEdouble();
if (type.isF32())
return APFloat::IEEEsingle();
if (type.isF64())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
FloatAttr FloatAttr::get(Type type, double value) { FloatAttr FloatAttr::get(Type type, double value) {
Optional<APFloat> val; Optional<APFloat> val;
if (type.isBF16() || type.isF16()) if (type.isBF16())
// Treat BF16 and F16 as double. This avoids needing to change the tensor // Treat BF16 as double because it is not supported in LLVM's APFloat.
// element parsing for now. TODO: Fix this to use the correct semantics // TODO(jpienaar): add BF16 support to APFloat?
// instead.
val = APFloat(value); val = APFloat(value);
else if (type.isF32()) else if (type.isF32())
val = APFloat(static_cast<float>(value)); val = APFloat(static_cast<float>(value));
else if (type.isF64()) else if (type.isF64())
val = APFloat(value); val = APFloat(value);
else { else {
// This handles, e.g., F16 because there is no APFloat constructor for it.
bool unused; bool unused;
val = APFloat(value); val = APFloat(value);
auto status = auto fltType = type.cast<FloatType>();
(*val).convert(getFloatSemantics(type), APFloat::rmTowardZero, &unused); auto status = (*val).convert(fltType.getFloatSemantics(),
APFloat::rmTowardZero, &unused);
if (status != APFloat::opOK) { if (status != APFloat::opOK) {
auto context = type.getContext(); auto context = type.getContext();
context->emitError( context->emitError(
@ -1226,8 +1207,10 @@ FloatAttr FloatAttr::get(Type type, double value) {
} }
FloatAttr FloatAttr::get(Type type, const APFloat &value) { FloatAttr FloatAttr::get(Type type, const APFloat &value) {
assert(&getFloatSemantics(type) == &value.getSemantics() && auto fltType = type.cast<FloatType>();
assert(&fltType.getFloatSemantics() == &value.getSemantics() &&
"FloatAttr type doesn't match the type implied by its value"); "FloatAttr type doesn't match the type implied by its value");
(void)fltType;
auto &impl = type.getContext()->getImpl(); auto &impl = type.getContext()->getImpl();
// Look to see if the float attribute has been created already. // Look to see if the float attribute has been created already.

View File

@ -19,6 +19,7 @@
#include "TypeDetail.h" #include "TypeDetail.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
using namespace mlir; using namespace mlir;
@ -55,6 +56,24 @@ unsigned FloatType::getWidth() const {
} }
} }
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
if (isBF16())
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
// not defined in LLVM.
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
// else one could add it.
// static const fltSemantics semBF16 = {127, -126, 8, 16};
return APFloat::IEEEdouble();
if (isF16())
return APFloat::IEEEhalf();
if (isF32())
return APFloat::IEEEsingle();
if (isF64())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
unsigned Type::getIntOrFloatBitWidth() const { unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only ints and floats have a bitwidth"); assert(isIntOrFloat() && "only ints and floats have a bitwidth");
if (auto intType = dyn_cast<IntegerType>()) { if (auto intType = dyn_cast<IntegerType>()) {

View File

@ -649,11 +649,9 @@ private:
void addToStorage(uint64_t value) { void addToStorage(uint64_t value) {
// Only tensors of integers or floats are supported. // Only tensors of integers or floats are supported.
// TODO: we currently use 64 bit for all floating point constants for legacy // FIXME: use full word to store BF16 as double because APFloat, which we
// reasoins. For f16 and f32, this is fixable by bitcasting APFloat value // use to work with floats, does not have support for BF16 yet.
// to APInt, but APFloat does not support bf16 semantics. size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
auto eltIntTy = eltTy.dyn_cast<IntegerType>();
size_t bitWidth = eltIntTy ? eltIntTy.getWidth() : 64;
if (bitWidth == 64) if (bitWidth == 64)
storage.push_back(value); storage.push_back(value);
@ -662,7 +660,7 @@ private:
storage.push_back(0L); storage.push_back(0L);
auto *rawData = reinterpret_cast<char *>(storage.data()); auto *rawData = reinterpret_cast<char *>(storage.data());
DenseIntElementsAttr::writeBits(rawData, currBitPos, bitWidth, value); DenseElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
currBitPos += bitWidth; currBitPos += bitWidth;
} }
@ -689,22 +687,24 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
return ParseResult::ParseFailure; return ParseResult::ParseFailure;
// check result matches the element type. // check result matches the element type.
switch (eltTy.getKind()) { switch (eltTy.getKind()) {
case Type::Kind::F32: {
if (!result.isa<FloatAttr>())
return p.emitError(
"expected tensor literal element has floating point type");
double value = result.cast<FloatAttr>().getValue().convertToFloat();
addToStorage(*(uint64_t *)(&value));
break;
}
case Type::Kind::BF16: case Type::Kind::BF16:
case Type::Kind::F16: case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: { case Type::Kind::F64: {
if (!result.isa<FloatAttr>()) // Bitcast the APFloat value to APInt and store the bit representation.
auto fpAttrResult = result.dyn_cast<FloatAttr>();
if (!fpAttrResult)
return p.emitError( return p.emitError(
"expected tensor literal element has floating point type"); "expected tensor literal element with floating point type");
double value = result.cast<FloatAttr>().getDouble(); auto apInt = fpAttrResult.getValue().bitcastToAPInt();
addToStorage(*(uint64_t *)(&value));
// FIXME: using 64 bits and double semantics for BF16 because APFloat does
// not support BF16 directly.
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
assert(apInt.getBitWidth() == bitWidth);
(void)bitWidth;
addToStorage(apInt.getRawData()[0]);
break; break;
} }
case Type::Kind::Integer: { case Type::Kind::Integer: {