forked from OSchip/llvm-project
[mlir][Parser] Use APFloat instead of FloatAttr when parsing DenseElementsAttrs.
Summary: DenseElementsAttr stores float values as raw bits internally, so creating attributes just to have them unwrapped is extremely inefficient. Differential Revision: https://reviews.llvm.org/D74818
This commit is contained in:
parent
6b6c96695c
commit
4a7364f1c2
|
@ -1734,22 +1734,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
|
|||
}
|
||||
|
||||
/// Construct a float attribute bitwise equivalent to the integer literal.
|
||||
static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
|
||||
uint64_t value) {
|
||||
static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
|
||||
uint64_t value) {
|
||||
// FIXME: bfloat is currently stored as a double internally because it doesn't
|
||||
// have valid APFloat semantics.
|
||||
if (type.isF64() || type.isBF16()) {
|
||||
APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
|
||||
return p->builder.getFloatAttr(type, apFloat);
|
||||
}
|
||||
if (type.isF64() || type.isBF16())
|
||||
return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
|
||||
|
||||
APInt apInt(type.getWidth(), value);
|
||||
if (apInt != value) {
|
||||
p->emitError("hexadecimal float constant out of range for type");
|
||||
return nullptr;
|
||||
return llvm::None;
|
||||
}
|
||||
APFloat apFloat(type.getFloatSemantics(), apInt);
|
||||
return p->builder.getFloatAttr(type, apFloat);
|
||||
return APFloat(type.getFloatSemantics(), apInt);
|
||||
}
|
||||
|
||||
/// Parse a decimal or a hexadecimal literal, which can be either an integer
|
||||
|
@ -1787,7 +1784,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
|
|||
}
|
||||
|
||||
// Construct a float attribute bitwise equivalent to the integer literal.
|
||||
return buildHexadecimalFloatLiteral(this, floatType, *val);
|
||||
Optional<APFloat> apVal =
|
||||
buildHexadecimalFloatLiteral(this, floatType, *val);
|
||||
return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
|
||||
}
|
||||
|
||||
if (!type.isIntOrIndex())
|
||||
|
@ -1996,7 +1995,7 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
|
|||
DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
|
||||
ShapedType type,
|
||||
FloatType eltTy) {
|
||||
std::vector<Attribute> floatValues;
|
||||
std::vector<APFloat> floatValues;
|
||||
floatValues.reserve(storage.size());
|
||||
for (const auto &signAndToken : storage) {
|
||||
bool isNegative = signAndToken.first;
|
||||
|
@ -2014,10 +2013,10 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
|
|||
p.emitError("hexadecimal float constant out of range for attribute");
|
||||
return nullptr;
|
||||
}
|
||||
FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
|
||||
if (!attr)
|
||||
Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
|
||||
if (!apVal)
|
||||
return nullptr;
|
||||
floatValues.push_back(attr);
|
||||
floatValues.push_back(*apVal);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -2033,7 +2032,14 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
|
|||
p.emitError("floating point value too large for attribute");
|
||||
return nullptr;
|
||||
}
|
||||
floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
|
||||
// Treat BF16 as double because it is not supported in LLVM's APFloat.
|
||||
APFloat apVal(isNegative ? -*val : *val);
|
||||
if (!eltTy.isBF16() && !eltTy.isF64()) {
|
||||
bool unused;
|
||||
apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
|
||||
&unused);
|
||||
}
|
||||
floatValues.push_back(apVal);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, floatValues);
|
||||
|
|
Loading…
Reference in New Issue