[mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.

Depending on where the 0 dimension is within the shape, the parser will currently reject .mlir generated by the printer.

Differential Revision: https://reviews.llvm.org/D83445
This commit is contained in:
River Riddle 2020-07-08 17:44:27 -07:00
parent 00068c452a
commit 24aa4efffd
3 changed files with 57 additions and 40 deletions

View File

@ -1432,10 +1432,12 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
os << "sparse<";
printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
/*allowHex=*/false);
os << ", ";
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
DenseIntElementsAttr indices = elementsAttr.getIndices();
if (indices.getNumElements() != 0) {
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
os << ", ";
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
}
os << '>';
break;
}
@ -1476,20 +1478,15 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
// Special case for degenerate tensors.
auto numElements = type.getNumElements();
int64_t rank = type.getRank();
if (numElements == 0) {
for (int i = 0; i < rank; ++i)
os << '[';
for (int i = 0; i < rank; ++i)
os << ']';
if (numElements == 0)
return;
}
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
// The mixed-radix counter, with radices in 'shape'.
int64_t rank = type.getRank();
SmallVector<unsigned, 4> counter(rank, 0);
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;

View File

@ -758,13 +758,13 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
// Parse the literal data.
// Parse the literal data if necessary.
TensorLiteralParser literalParser(*this);
if (literalParser.parse(/*allowHex=*/true))
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
if (!consumeIf(Token::greater)) {
if (literalParser.parse(/*allowHex=*/true) ||
parseToken(Token::greater, "expected '>'"))
return nullptr;
}
auto typeLoc = getToken().getLoc();
auto type = parseElementsLiteralType(attrType);
@ -841,6 +841,25 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
// Check for the case where all elements are sparse. The indices are
// represented by a 2-dimensional shape where the second dimension is the rank
// of the type.
Type indiceEltType = builder.getIntegerType(64);
if (consumeIf(Token::greater)) {
ShapedType type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
// Construct the sparse elements attr using zero element indice/value
// attributes.
ShapedType indicesType =
RankedTensorType::get({0, type.getRank()}, indiceEltType);
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
return SparseElementsAttr::get(
type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
}
/// Parse the indices. We don't allow hex values here as we may need to use
/// the inferred shape.
auto indicesLoc = getToken().getLoc();
@ -869,7 +888,6 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// 2-dimensional shape where the second dimension is the rank of the type.
// Given that the parsed indices is a splat, we know that we only have one
// indice and thus one for the first dimension.
auto indiceEltType = builder.getIntegerType(64);
ShapedType indicesType;
if (indiceParser.getShape().empty()) {
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);

View File

@ -670,19 +670,21 @@ func @densetensorattr() -> () {
// CHECK: "fooi67"() {bar = dense<{{\[\[\[}}-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
"fooi67"(){bar = dense<[[[-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
// CHECK: "foo2"() {bar = dense<[]> : tensor<0xi32>} : () -> ()
"foo2"(){bar = dense<[]> : tensor<0xi32>} : () -> ()
// CHECK: "foo2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xi32>} : () -> ()
"foo2"(){bar = dense<[[]]> : tensor<1x0xi32>} : () -> ()
// CHECK: "foo2"() {bar = dense<> : tensor<0xi32>} : () -> ()
"foo2"(){bar = dense<> : tensor<0xi32>} : () -> ()
// CHECK: "foo2"() {bar = dense<> : tensor<1x0xi32>} : () -> ()
"foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> ()
// CHECK: dense<> : tensor<0x512x512xi32>
"foo2"(){bar = dense<> : tensor<0x512x512xi32>} : () -> ()
// CHECK: "foo3"() {bar = dense<{{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
"foo3"(){bar = dense<[[[5, -6, 1, 2]], [[7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
// CHECK: "float1"() {bar = dense<5.000000e+00> : tensor<1x1x1xf32>} : () -> ()
"float1"(){bar = dense<[[[5.0]]]> : tensor<1x1x1xf32>} : () -> ()
// CHECK: "float2"() {bar = dense<[]> : tensor<0xf32>} : () -> ()
"float2"(){bar = dense<[]> : tensor<0xf32>} : () -> ()
// CHECK: "float2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xf32>} : () -> ()
"float2"(){bar = dense<[[]]> : tensor<1x0xf32>} : () -> ()
// CHECK: "float2"() {bar = dense<> : tensor<0xf32>} : () -> ()
"float2"(){bar = dense<> : tensor<0xf32>} : () -> ()
// CHECK: "float2"() {bar = dense<> : tensor<1x0xf32>} : () -> ()
"float2"(){bar = dense<> : tensor<1x0xf32>} : () -> ()
// CHECK: "bfloat16"() {bar = dense<{{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]> : tensor<2x1x4xbf16>} : () -> ()
"bfloat16"(){bar = dense<[[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]> : tensor<2x1x4xbf16>} : () -> ()
@ -752,27 +754,27 @@ func @sparsetensorattr() -> () {
"fooi8"(){bar = sparse<0, -2> : tensor<1x1x1xi8>} : () -> ()
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x1xi32>} : () -> ()
"fooi32"(){bar = sparse<[], []> : tensor<1x1xi32>} : () -> ()
// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
"fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
// CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
"fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
// CHECK: "foo2"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xi32>} : () -> ()
"foo2"(){bar = sparse<[], []> : tensor<0xi32>} : () -> ()
// CHECK: "foo3"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<i32>} : () -> ()
"foo3"(){bar = sparse<[], []> : tensor<i32>} : () -> ()
// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
"foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
// CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
"foo3"(){bar = sparse<> : tensor<i32>} : () -> ()
// CHECK: "foof16"() {bar = sparse<0, -2.000000e+00> : tensor<1x1x1xf16>} : () -> ()
"foof16"(){bar = sparse<0, -2.0> : tensor<1x1x1xf16>} : () -> ()
// CHECK: "foobf16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]> : tensor<2x2x2xbf16>} : () -> ()
"foobf16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]> : tensor<2x2x2xbf16>} : () -> ()
// CHECK: "foof32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x0x1xf32>} : () -> ()
"foof32"(){bar = sparse<[], []> : tensor<1x0x1xf32>} : () -> ()
// CHECK: "foof32"() {bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
"foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
// CHECK: "foof64"() {bar = sparse<0, -1.000000e+00> : tensor<1xf64>} : () -> ()
"foof64"(){bar = sparse<[[0]], [-1.0]> : tensor<1xf64>} : () -> ()
// CHECK: "foof320"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xf32>} : () -> ()
"foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
"foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
// CHECK: "foof320"() {bar = sparse<> : tensor<0xf32>} : () -> ()
"foof320"(){bar = sparse<> : tensor<0xf32>} : () -> ()
// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
"foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
"foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
@ -789,8 +791,8 @@ func @sparsevectorattr() -> () {
"fooi8"(){bar = sparse<0, -2> : vector<1x1x1xi8>} : () -> ()
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : vector<1x1xi32>} : () -> ()
"fooi32"(){bar = sparse<[], []> : vector<1x1xi32>} : () -> ()
// CHECK: "fooi32"() {bar = sparse<> : vector<1x1xi32>} : () -> ()
"fooi32"(){bar = sparse<> : vector<1x1xi32>} : () -> ()
// CHECK: "fooi64"() {bar = sparse<0, -1> : vector<1xi64>} : () -> ()
"fooi64"(){bar = sparse<[[0]], [-1]> : vector<1xi64>} : () -> ()