From c2415d67a5644eae870328497d46d97b28d5a974 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 26 Jul 2021 17:14:30 -0700 Subject: [PATCH] [mlir][sparse] fixed bug in verification The order of testing in two sparse tensor ops was incorrect, which could cause an invalid cast (crashing the compiler instead of reporting the error). This revision fixes that bug. Reviewed By: gussmith23 Differential Revision: https://reviews.llvm.org/D106841 --- .../SparseTensor/IR/SparseTensorDialect.cpp | 8 ++++---- mlir/test/Dialect/SparseTensor/invalid.mlir | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index e07dfdcb7f0c..09cb88664d26 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -214,9 +214,9 @@ static LogicalResult verify(NewOp op) { } static LogicalResult verify(ToPointersOp op) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested pointers dimension out of bounds"); if (auto e = getSparseTensorEncoding(op.tensor().getType())) { + if (failed(isInBounds(op.dim(), op.tensor()))) + return op.emitError("requested pointers dimension out of bounds"); if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) return op.emitError("unexpected type for pointers"); return success(); @@ -225,9 +225,9 @@ static LogicalResult verify(ToPointersOp op) { } static LogicalResult verify(ToIndicesOp op) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested indices dimension out of bounds"); if (auto e = getSparseTensorEncoding(op.tensor().getType())) { + if (failed(isInBounds(op.dim(), op.tensor()))) + return op.emitError("requested indices dimension out of bounds"); if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) return op.emitError("unexpected type for indices"); return success(); diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 06a63cf37cf5..07cae4fc4a16 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -17,6 +17,15 @@ func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref { // ----- +func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref { + %c = constant 0 : index + // expected-error@+1 {{expected a sparse tensor to get pointers}} + %0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref + return %0 : memref +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"], pointerBitWidth=32}> func @mismatch_pointers_types(%arg0: tensor<128xf64, #SparseVector>) -> memref { @@ -48,6 +57,15 @@ func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref { // ----- +func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref { + %c = constant 0 : index + // expected-error@+1 {{expected a sparse tensor to get indices}} + %0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref + return %0 : memref +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> func @mismatch_indices_types(%arg0: tensor) -> memref {