[mlir][sparse] remove a few rewriting failures

Rationale:
Make sure preconditions are tested already during verfication.
Currently, the only way a sparse rewriting rule can fail is if
(1) the linalg op does not have sparse annotations, or
(2) a yet to be handled operation is encounted inside the op

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D91748
This commit is contained in:
Aart Bik 2020-11-18 15:35:57 -08:00
parent 803af31e5b
commit 9ad62f62b9
3 changed files with 85 additions and 19 deletions

View File

@ -419,6 +419,8 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
// Verify consistency of sparse annotations. // Verify consistency of sparse annotations.
if (!op.hasTensorSemantics()) if (!op.hasTensorSemantics())
return op.emitOpError("expected sparse annotations on tensors only"); return op.emitOpError("expected sparse annotations on tensors only");
if (op.getNumOutputs() != 1)
return op.emitOpError("expected single output tensor");
unsigned numTensors = op.getNumInputsAndOutputs(); unsigned numTensors = op.getNumInputsAndOutputs();
if (sparseAttr.size() != numTensors) if (sparseAttr.size() != numTensors)
return op.emitOpError("expected one sparse annotation for each tensor"); return op.emitOpError("expected one sparse annotation for each tensor");

View File

@ -830,22 +830,16 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp op, LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numLoops = op.iterator_types().getValue().size();
Merger merger(numTensors, numLoops);
// Detects sparse annotations and translate the per-dimension sparsity // Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel. // information for all tensors to loop indices in the kernel.
if (!op.hasSparseSemantics()) if (!op.hasSparseSemantics())
return failure(); return failure();
assert(op.getNumOutputs() == 1);
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numLoops = op.iterator_types().getValue().size();
Merger merger(numTensors, numLoops);
findSparseAnnotations(op, merger.sparse()); findSparseAnnotations(op, merger.sparse());
// Accept only single, dense result.
if (op.getNumOutputs() != 1 ||
std::any_of(merger.sparse().back().begin(),
merger.sparse().back().end(), [](bool b) { return b; }))
return failure();
// Computes a topologically sorted iteration graph to ensure // Computes a topologically sorted iteration graph to ensure
// tensors are visited in natural index order. Fails on cycles. // tensors are visited in natural index order. Fails on cycles.
// This assumes that higher-level passes have already put the // This assumes that higher-level passes have already put the
@ -858,10 +852,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Finds the terminating yield statement and builds the tensor // Finds the terminating yield statement and builds the tensor
// expression for the Linalg operation in SSA form. // expression for the Linalg operation in SSA form.
auto &region = op.region(); Operation *yield = op.region().front().getTerminator();
if (!llvm::hasSingleElement(region))
return failure(); // single block only
Operation *yield = region.front().getTerminator();
Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0)); Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
if (!exp.hasValue()) if (!exp.hasValue())
return failure(); // build failure return failure(); // build failure

View File

@ -25,6 +25,79 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
// ----- // -----
#trait_two_out = {
indexing_maps = [
affine_map<(i) -> (i)>, // a
affine_map<(i) -> (i)>, // x (out)
affine_map<(i) -> (i)> // y (out)
],
sparse = [
[ "S" ], // a
[ "D" ], // x
[ "D" ] // y
],
iterator_types = ["parallel"]
}
func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> {
// expected-error@+1 {{'linalg.generic' op expected single output tensor}}
%0, %1 = linalg.generic #trait_two_out
ins(%arga: tensor<32xf32>) {
^bb(%a: f32):
%0 = addf %a, %a : f32
linalg.yield %a, %0 : f32, f32
} -> tensor<32xf32>, tensor<32xf32>
return %1 : tensor<32xf32>
}
// -----
#trait_two_blocks = {
indexing_maps = [
affine_map<(i) -> (i)>, // a
affine_map<(i) -> (i)> // x (out)
],
sparse = [
[ "S" ], // a
[ "D" ] // x
],
iterator_types = ["parallel"]
}
func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> {
// expected-error@+1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}}
%0 = linalg.generic #trait_two_blocks
ins(%arga: tensor<32xf32>) {
^bb1(%a: f32):
%0 = addf %a, %a : f32
^bb2:
linalg.yield %0 : f32
} -> tensor<32xf32>
return %0 : tensor<32xf32>
}
// -----
#trait_no_block = {
indexing_maps = [
affine_map<(i) -> (i)> // a
],
sparse = [
[ "S" ] // a
],
iterator_types = ["parallel"]
}
func @invalid_no_block(%arga: tensor<32xf32>) {
// expected-error@+1 {{'linalg.generic' op expected region with 1 block}}
linalg.generic #trait_no_block
ins(%arga: tensor<32xf32>) {
}
return
}
// -----
#trait_too_many = { #trait_too_many = {
indexing_maps = [ indexing_maps = [
affine_map<(i) -> (i)>, // a affine_map<(i) -> (i)>, // a