[mlir][sparse] fixed bug in verification

The order of testing in two sparse tensor ops was incorrect,
which could cause an invalid cast (crashing the compiler instead
of reporting the error). This revision fixes that bug.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106841
This commit is contained in:
Aart Bik 2021-07-26 17:14:30 -07:00
parent b32d3d9e81
commit c2415d67a5
2 changed files with 22 additions and 4 deletions

View File

@ -214,9 +214,9 @@ static LogicalResult verify(NewOp op) {
}
static LogicalResult verify(ToPointersOp op) {
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isInBounds(op.dim(), op.tensor())))
return op.emitError("requested pointers dimension out of bounds");
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
return op.emitError("unexpected type for pointers");
return success();
@ -225,9 +225,9 @@ static LogicalResult verify(ToPointersOp op) {
}
static LogicalResult verify(ToIndicesOp op) {
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isInBounds(op.dim(), op.tensor())))
return op.emitError("requested indices dimension out of bounds");
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
return op.emitError("unexpected type for indices");
return success();

View File

@ -17,6 +17,15 @@ func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
// -----
func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = constant 0 : index
// expected-error@+1 {{expected a sparse tensor to get pointers}}
%0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"], pointerBitWidth=32}>
func @mismatch_pointers_types(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
@ -48,6 +57,15 @@ func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
// -----
func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = constant 0 : index
// expected-error@+1 {{expected a sparse tensor to get indices}}
%0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
func @mismatch_indices_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xi32> {