[mlir][DenseElementsAttr] Fix storage size for bfloat16 when parsing from hex.

Summary: bfloat16 is stored internally as a double, so we can't direct use Type::getIntOrFloatBitWidth.

Differential Revision: https://reviews.llvm.org/D75133
This commit is contained in:
River Riddle 2020-02-25 14:56:24 -08:00
parent 4b2b8b96db
commit b3e6487f02
2 changed files with 8 additions and 1 deletions

View File

@ -2118,8 +2118,12 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
// Check that the size of the hex data correpsonds to the size of the type, or
// a splat of the type.
// TODO: bf16 is currently stored as a double, this should be removed when
// APFloat properly supports it.
int64_t elementWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
(type.getNumElements() * elementType.getIntOrFloatBitWidth())) {
(type.getNumElements() * elementWidth)) {
p.emitError(loc) << "elements hex data size is invalid for provided type: "
<< type;
return nullptr;

View File

@ -7,6 +7,9 @@
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()
// -----
// expected-error@+1 {{elements hex string should start with '0x'}}