From 4a7364f1c2ef0c45d7e603799fe0b7662d4c4078 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 19 Feb 2020 10:28:53 -0800 Subject: [PATCH] [mlir][Parser] Use APFloat instead of FloatAttr when parsing DenseElementsAttrs. Summary: DenseElementsAttr stores float values as raw bits internally, so creating attributes just to have them unwrapped is extremely inefficient. Differential Revision: https://reviews.llvm.org/D74818 --- mlir/lib/Parser/Parser.cpp | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 8bd57a11888c..2a2219c4202f 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1734,22 +1734,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { } /// Construct a float attribute bitwise equivalent to the integer literal. -static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, - uint64_t value) { +static Optional buildHexadecimalFloatLiteral(Parser *p, FloatType type, + uint64_t value) { // FIXME: bfloat is currently stored as a double internally because it doesn't // have valid APFloat semantics. - if (type.isF64() || type.isBF16()) { - APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); - return p->builder.getFloatAttr(type, apFloat); - } + if (type.isF64() || type.isBF16()) + return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); APInt apInt(type.getWidth(), value); if (apInt != value) { p->emitError("hexadecimal float constant out of range for type"); - return nullptr; + return llvm::None; } - APFloat apFloat(type.getFloatSemantics(), apInt); - return p->builder.getFloatAttr(type, apFloat); + return APFloat(type.getFloatSemantics(), apInt); } /// Parse a decimal or a hexadecimal literal, which can be either an integer @@ -1787,7 +1784,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { } // Construct a float attribute bitwise equivalent to the integer literal. - return buildHexadecimalFloatLiteral(this, floatType, *val); + Optional apVal = + buildHexadecimalFloatLiteral(this, floatType, *val); + return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); } if (!type.isIntOrIndex()) @@ -1996,7 +1995,7 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy) { - std::vector floatValues; + std::vector floatValues; floatValues.reserve(storage.size()); for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; @@ -2014,10 +2013,10 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, p.emitError("hexadecimal float constant out of range for attribute"); return nullptr; } - FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); - if (!attr) + Optional apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); + if (!apVal) return nullptr; - floatValues.push_back(attr); + floatValues.push_back(*apVal); continue; } @@ -2033,7 +2032,14 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, p.emitError("floating point value too large for attribute"); return nullptr; } - floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); + // Treat BF16 as double because it is not supported in LLVM's APFloat. + APFloat apVal(isNegative ? -*val : *val); + if (!eltTy.isBF16() && !eltTy.isF64()) { + bool unused; + apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + } + floatValues.push_back(apVal); } return DenseElementsAttr::get(type, floatValues);