forked from OSchip/llvm-project
[mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.
Depending on where the 0 dimension is within the shape, the parser will currently reject .mlir generated by the printer. Differential Revision: https://reviews.llvm.org/D83445
This commit is contained in:
parent
00068c452a
commit
24aa4efffd
|
@ -1432,10 +1432,12 @@ void ModulePrinter::printAttribute(Attribute attr,
|
|||
break;
|
||||
}
|
||||
os << "sparse<";
|
||||
printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
|
||||
/*allowHex=*/false);
|
||||
os << ", ";
|
||||
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
|
||||
DenseIntElementsAttr indices = elementsAttr.getIndices();
|
||||
if (indices.getNumElements() != 0) {
|
||||
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
|
||||
os << ", ";
|
||||
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
|
||||
}
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
|
@ -1476,20 +1478,15 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
|
|||
|
||||
// Special case for degenerate tensors.
|
||||
auto numElements = type.getNumElements();
|
||||
int64_t rank = type.getRank();
|
||||
if (numElements == 0) {
|
||||
for (int i = 0; i < rank; ++i)
|
||||
os << '[';
|
||||
for (int i = 0; i < rank; ++i)
|
||||
os << ']';
|
||||
if (numElements == 0)
|
||||
return;
|
||||
}
|
||||
|
||||
// We use a mixed-radix counter to iterate through the shape. When we bump a
|
||||
// non-least-significant digit, we emit a close bracket. When we next emit an
|
||||
// element we re-open all closed brackets.
|
||||
|
||||
// The mixed-radix counter, with radices in 'shape'.
|
||||
int64_t rank = type.getRank();
|
||||
SmallVector<unsigned, 4> counter(rank, 0);
|
||||
// The number of brackets that have been opened and not closed.
|
||||
unsigned openBrackets = 0;
|
||||
|
|
|
@ -758,13 +758,13 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
|
|||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
|
||||
// Parse the literal data.
|
||||
// Parse the literal data if necessary.
|
||||
TensorLiteralParser literalParser(*this);
|
||||
if (literalParser.parse(/*allowHex=*/true))
|
||||
return nullptr;
|
||||
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
if (!consumeIf(Token::greater)) {
|
||||
if (literalParser.parse(/*allowHex=*/true) ||
|
||||
parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto typeLoc = getToken().getLoc();
|
||||
auto type = parseElementsLiteralType(attrType);
|
||||
|
@ -841,6 +841,25 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
|
|||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||
return nullptr;
|
||||
|
||||
// Check for the case where all elements are sparse. The indices are
|
||||
// represented by a 2-dimensional shape where the second dimension is the rank
|
||||
// of the type.
|
||||
Type indiceEltType = builder.getIntegerType(64);
|
||||
if (consumeIf(Token::greater)) {
|
||||
ShapedType type = parseElementsLiteralType(attrType);
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
// Construct the sparse elements attr using zero element indice/value
|
||||
// attributes.
|
||||
ShapedType indicesType =
|
||||
RankedTensorType::get({0, type.getRank()}, indiceEltType);
|
||||
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
|
||||
return SparseElementsAttr::get(
|
||||
type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
|
||||
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
|
||||
}
|
||||
|
||||
/// Parse the indices. We don't allow hex values here as we may need to use
|
||||
/// the inferred shape.
|
||||
auto indicesLoc = getToken().getLoc();
|
||||
|
@ -869,7 +888,6 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
|
|||
// 2-dimensional shape where the second dimension is the rank of the type.
|
||||
// Given that the parsed indices is a splat, we know that we only have one
|
||||
// indice and thus one for the first dimension.
|
||||
auto indiceEltType = builder.getIntegerType(64);
|
||||
ShapedType indicesType;
|
||||
if (indiceParser.getShape().empty()) {
|
||||
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
|
||||
|
|
|
@ -670,19 +670,21 @@ func @densetensorattr() -> () {
|
|||
// CHECK: "fooi67"() {bar = dense<{{\[\[\[}}-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
|
||||
"fooi67"(){bar = dense<[[[-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
|
||||
|
||||
// CHECK: "foo2"() {bar = dense<[]> : tensor<0xi32>} : () -> ()
|
||||
"foo2"(){bar = dense<[]> : tensor<0xi32>} : () -> ()
|
||||
// CHECK: "foo2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xi32>} : () -> ()
|
||||
"foo2"(){bar = dense<[[]]> : tensor<1x0xi32>} : () -> ()
|
||||
// CHECK: "foo2"() {bar = dense<> : tensor<0xi32>} : () -> ()
|
||||
"foo2"(){bar = dense<> : tensor<0xi32>} : () -> ()
|
||||
// CHECK: "foo2"() {bar = dense<> : tensor<1x0xi32>} : () -> ()
|
||||
"foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> ()
|
||||
// CHECK: dense<> : tensor<0x512x512xi32>
|
||||
"foo2"(){bar = dense<> : tensor<0x512x512xi32>} : () -> ()
|
||||
// CHECK: "foo3"() {bar = dense<{{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
|
||||
"foo3"(){bar = dense<[[[5, -6, 1, 2]], [[7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
|
||||
|
||||
// CHECK: "float1"() {bar = dense<5.000000e+00> : tensor<1x1x1xf32>} : () -> ()
|
||||
"float1"(){bar = dense<[[[5.0]]]> : tensor<1x1x1xf32>} : () -> ()
|
||||
// CHECK: "float2"() {bar = dense<[]> : tensor<0xf32>} : () -> ()
|
||||
"float2"(){bar = dense<[]> : tensor<0xf32>} : () -> ()
|
||||
// CHECK: "float2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xf32>} : () -> ()
|
||||
"float2"(){bar = dense<[[]]> : tensor<1x0xf32>} : () -> ()
|
||||
// CHECK: "float2"() {bar = dense<> : tensor<0xf32>} : () -> ()
|
||||
"float2"(){bar = dense<> : tensor<0xf32>} : () -> ()
|
||||
// CHECK: "float2"() {bar = dense<> : tensor<1x0xf32>} : () -> ()
|
||||
"float2"(){bar = dense<> : tensor<1x0xf32>} : () -> ()
|
||||
|
||||
// CHECK: "bfloat16"() {bar = dense<{{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]> : tensor<2x1x4xbf16>} : () -> ()
|
||||
"bfloat16"(){bar = dense<[[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]> : tensor<2x1x4xbf16>} : () -> ()
|
||||
|
@ -752,27 +754,27 @@ func @sparsetensorattr() -> () {
|
|||
"fooi8"(){bar = sparse<0, -2> : tensor<1x1x1xi8>} : () -> ()
|
||||
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
|
||||
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
|
||||
// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x1xi32>} : () -> ()
|
||||
"fooi32"(){bar = sparse<[], []> : tensor<1x1xi32>} : () -> ()
|
||||
// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
|
||||
"fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
|
||||
// CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
|
||||
"fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
|
||||
// CHECK: "foo2"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xi32>} : () -> ()
|
||||
"foo2"(){bar = sparse<[], []> : tensor<0xi32>} : () -> ()
|
||||
// CHECK: "foo3"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<i32>} : () -> ()
|
||||
"foo3"(){bar = sparse<[], []> : tensor<i32>} : () -> ()
|
||||
// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
|
||||
"foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
|
||||
// CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
|
||||
"foo3"(){bar = sparse<> : tensor<i32>} : () -> ()
|
||||
|
||||
// CHECK: "foof16"() {bar = sparse<0, -2.000000e+00> : tensor<1x1x1xf16>} : () -> ()
|
||||
"foof16"(){bar = sparse<0, -2.0> : tensor<1x1x1xf16>} : () -> ()
|
||||
// CHECK: "foobf16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]> : tensor<2x2x2xbf16>} : () -> ()
|
||||
"foobf16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]> : tensor<2x2x2xbf16>} : () -> ()
|
||||
// CHECK: "foof32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x0x1xf32>} : () -> ()
|
||||
"foof32"(){bar = sparse<[], []> : tensor<1x0x1xf32>} : () -> ()
|
||||
// CHECK: "foof32"() {bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
|
||||
"foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
|
||||
// CHECK: "foof64"() {bar = sparse<0, -1.000000e+00> : tensor<1xf64>} : () -> ()
|
||||
"foof64"(){bar = sparse<[[0]], [-1.0]> : tensor<1xf64>} : () -> ()
|
||||
// CHECK: "foof320"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xf32>} : () -> ()
|
||||
"foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
|
||||
// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
|
||||
"foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
|
||||
// CHECK: "foof320"() {bar = sparse<> : tensor<0xf32>} : () -> ()
|
||||
"foof320"(){bar = sparse<> : tensor<0xf32>} : () -> ()
|
||||
// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
|
||||
"foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
|
||||
|
||||
// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
|
||||
"foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
|
||||
|
@ -789,8 +791,8 @@ func @sparsevectorattr() -> () {
|
|||
"fooi8"(){bar = sparse<0, -2> : vector<1x1x1xi8>} : () -> ()
|
||||
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
|
||||
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
|
||||
// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : vector<1x1xi32>} : () -> ()
|
||||
"fooi32"(){bar = sparse<[], []> : vector<1x1xi32>} : () -> ()
|
||||
// CHECK: "fooi32"() {bar = sparse<> : vector<1x1xi32>} : () -> ()
|
||||
"fooi32"(){bar = sparse<> : vector<1x1xi32>} : () -> ()
|
||||
// CHECK: "fooi64"() {bar = sparse<0, -1> : vector<1xi64>} : () -> ()
|
||||
"fooi64"(){bar = sparse<[[0]], [-1]> : vector<1xi64>} : () -> ()
|
||||
|
||||
|
|
Loading…
Reference in New Issue