forked from OSchip/llvm-project
[mlir][sparse] remove a few rewriting failures
Rationale: Make sure preconditions are tested already during verfication. Currently, the only way a sparse rewriting rule can fail is if (1) the linalg op does not have sparse annotations, or (2) a yet to be handled operation is encounted inside the op Reviewed By: penpornk Differential Revision: https://reviews.llvm.org/D91748
This commit is contained in:
parent
803af31e5b
commit
9ad62f62b9
|
@ -419,6 +419,8 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
|
||||||
// Verify consistency of sparse annotations.
|
// Verify consistency of sparse annotations.
|
||||||
if (!op.hasTensorSemantics())
|
if (!op.hasTensorSemantics())
|
||||||
return op.emitOpError("expected sparse annotations on tensors only");
|
return op.emitOpError("expected sparse annotations on tensors only");
|
||||||
|
if (op.getNumOutputs() != 1)
|
||||||
|
return op.emitOpError("expected single output tensor");
|
||||||
unsigned numTensors = op.getNumInputsAndOutputs();
|
unsigned numTensors = op.getNumInputsAndOutputs();
|
||||||
if (sparseAttr.size() != numTensors)
|
if (sparseAttr.size() != numTensors)
|
||||||
return op.emitOpError("expected one sparse annotation for each tensor");
|
return op.emitOpError("expected one sparse annotation for each tensor");
|
||||||
|
|
|
@ -830,22 +830,16 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(linalg::GenericOp op,
|
LogicalResult matchAndRewrite(linalg::GenericOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
unsigned numTensors = op.getNumInputsAndOutputs();
|
|
||||||
unsigned numLoops = op.iterator_types().getValue().size();
|
|
||||||
Merger merger(numTensors, numLoops);
|
|
||||||
|
|
||||||
// Detects sparse annotations and translate the per-dimension sparsity
|
// Detects sparse annotations and translate the per-dimension sparsity
|
||||||
// information for all tensors to loop indices in the kernel.
|
// information for all tensors to loop indices in the kernel.
|
||||||
if (!op.hasSparseSemantics())
|
if (!op.hasSparseSemantics())
|
||||||
return failure();
|
return failure();
|
||||||
|
assert(op.getNumOutputs() == 1);
|
||||||
|
unsigned numTensors = op.getNumInputsAndOutputs();
|
||||||
|
unsigned numLoops = op.iterator_types().getValue().size();
|
||||||
|
Merger merger(numTensors, numLoops);
|
||||||
findSparseAnnotations(op, merger.sparse());
|
findSparseAnnotations(op, merger.sparse());
|
||||||
|
|
||||||
// Accept only single, dense result.
|
|
||||||
if (op.getNumOutputs() != 1 ||
|
|
||||||
std::any_of(merger.sparse().back().begin(),
|
|
||||||
merger.sparse().back().end(), [](bool b) { return b; }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Computes a topologically sorted iteration graph to ensure
|
// Computes a topologically sorted iteration graph to ensure
|
||||||
// tensors are visited in natural index order. Fails on cycles.
|
// tensors are visited in natural index order. Fails on cycles.
|
||||||
// This assumes that higher-level passes have already put the
|
// This assumes that higher-level passes have already put the
|
||||||
|
@ -858,10 +852,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
|
||||||
|
|
||||||
// Finds the terminating yield statement and builds the tensor
|
// Finds the terminating yield statement and builds the tensor
|
||||||
// expression for the Linalg operation in SSA form.
|
// expression for the Linalg operation in SSA form.
|
||||||
auto ®ion = op.region();
|
Operation *yield = op.region().front().getTerminator();
|
||||||
if (!llvm::hasSingleElement(region))
|
|
||||||
return failure(); // single block only
|
|
||||||
Operation *yield = region.front().getTerminator();
|
|
||||||
Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
|
Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
|
||||||
if (!exp.hasValue())
|
if (!exp.hasValue())
|
||||||
return failure(); // build failure
|
return failure(); // build failure
|
||||||
|
|
|
@ -25,6 +25,79 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#trait_two_out = {
|
||||||
|
indexing_maps = [
|
||||||
|
affine_map<(i) -> (i)>, // a
|
||||||
|
affine_map<(i) -> (i)>, // x (out)
|
||||||
|
affine_map<(i) -> (i)> // y (out)
|
||||||
|
],
|
||||||
|
sparse = [
|
||||||
|
[ "S" ], // a
|
||||||
|
[ "D" ], // x
|
||||||
|
[ "D" ] // y
|
||||||
|
],
|
||||||
|
iterator_types = ["parallel"]
|
||||||
|
}
|
||||||
|
|
||||||
|
func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> {
|
||||||
|
// expected-error@+1 {{'linalg.generic' op expected single output tensor}}
|
||||||
|
%0, %1 = linalg.generic #trait_two_out
|
||||||
|
ins(%arga: tensor<32xf32>) {
|
||||||
|
^bb(%a: f32):
|
||||||
|
%0 = addf %a, %a : f32
|
||||||
|
linalg.yield %a, %0 : f32, f32
|
||||||
|
} -> tensor<32xf32>, tensor<32xf32>
|
||||||
|
return %1 : tensor<32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#trait_two_blocks = {
|
||||||
|
indexing_maps = [
|
||||||
|
affine_map<(i) -> (i)>, // a
|
||||||
|
affine_map<(i) -> (i)> // x (out)
|
||||||
|
],
|
||||||
|
sparse = [
|
||||||
|
[ "S" ], // a
|
||||||
|
[ "D" ] // x
|
||||||
|
],
|
||||||
|
iterator_types = ["parallel"]
|
||||||
|
}
|
||||||
|
|
||||||
|
func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> {
|
||||||
|
// expected-error@+1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}}
|
||||||
|
%0 = linalg.generic #trait_two_blocks
|
||||||
|
ins(%arga: tensor<32xf32>) {
|
||||||
|
^bb1(%a: f32):
|
||||||
|
%0 = addf %a, %a : f32
|
||||||
|
^bb2:
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<32xf32>
|
||||||
|
return %0 : tensor<32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#trait_no_block = {
|
||||||
|
indexing_maps = [
|
||||||
|
affine_map<(i) -> (i)> // a
|
||||||
|
],
|
||||||
|
sparse = [
|
||||||
|
[ "S" ] // a
|
||||||
|
],
|
||||||
|
iterator_types = ["parallel"]
|
||||||
|
}
|
||||||
|
|
||||||
|
func @invalid_no_block(%arga: tensor<32xf32>) {
|
||||||
|
// expected-error@+1 {{'linalg.generic' op expected region with 1 block}}
|
||||||
|
linalg.generic #trait_no_block
|
||||||
|
ins(%arga: tensor<32xf32>) {
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
#trait_too_many = {
|
#trait_too_many = {
|
||||||
indexing_maps = [
|
indexing_maps = [
|
||||||
affine_map<(i) -> (i)>, // a
|
affine_map<(i) -> (i)>, // a
|
||||||
|
|
Loading…
Reference in New Issue