[mlir][sparse] refine bufferization allocation lowering

Marking the bufferization allocation operation as
illegal during sparse lowering is too strict, since
dense and sparse allocations can co-exist. This
revision refines the lowering with a dynamic type
check.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D128305
Author: Aart Bik
Date:   2022-06-21 14:13:14 -07:00
Parent: 8b8d126598
Commit: fde04aee33

2 changed files with 18 additions and 4 deletions
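
The crux of the change is to replace the blanket target.addIllegalOp<bufferization::AllocTensorOp>() rule with a dynamic legality check against the type converter, so dense allocations pass through untouched while sparse ones are still forced through the rewriting patterns. Below is a minimal sketch of that pattern, assuming a type converter that rewrites sparse tensor types and leaves dense tensor types unchanged; the helper function and its name are illustrative and not part of the patch.

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative helper (hypothetical, not in the patch): keep
// bufferization.alloc_tensor legal only when its result type needs no
// conversion, i.e. when it allocates a dense tensor; sparse allocations
// remain illegal and therefore must be rewritten.
static void configureAllocTensorLegality(ConversionTarget &target,
                                         TypeConverter &converter) {
  // Old behavior rejected every alloc_tensor outright:
  //   target.addIllegalOp<bufferization::AllocTensorOp>();
  // New behavior decides per operation with a dynamic type check. The lambda
  // captures the converter by reference, so it must outlive the conversion.
  target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
      [&](bufferization::AllocTensorOp op) {
        return converter.isLegal(op.getType());
      });
}

With this rule in place, the new test at the end of the diff checks that a sparse alloc_tensor lowers to a newSparseTensor call while a dense alloc_tensor is left as is.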

@@ -92,9 +92,9 @@ struct SparseTensorConversionPass
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return, and tensor
-    // dim and cast operations as legal output of the rewriting provided that
-    // all sparse tensor types have been fully rewritten.
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting
+    // provided that all sparse tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -110,6 +110,10 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
       return converter.isLegal(op.getOperand().getType());
     });
+    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+        [&](bufferization::AllocTensorOp op) {
+          return converter.isLegal(op.getType());
+        });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
@@ -119,7 +123,6 @@ struct SparseTensorConversionPass
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect, scf::SCFDialect>();
-    target.addIllegalOp<bufferization::AllocTensorOp>();
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));

@@ -572,3 +572,14 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
   sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
   return
 }
+
+// CHECK-LABEL: func @sparse_and_dense_init(
+// CHECK: %[[S:.*]] = call @newSparseTensor
+// CHECK: %[[D:.*]] = bufferization.alloc_tensor
+// CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
+func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
+  -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
+  %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+  return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+}