[mlir][sparse] refine bufferization allocation lowering

Marking the bufferization allocation operation as invalid during sparse
lowering is too strict, since dense and sparse allocations can co-exist.
This revision refines the lowering with a dynamic type check, so that only
allocations of sparse tensor types are rewritten.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D128305
commit fde04aee33
parent 8b8d126598
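The heart of the change is that a bufferization.alloc_tensor op is now only treated as illegal when its result type is a sparse tensor. In the pass this decision is delegated to the type converter (converter.isLegal(op.getType()), see the diff below). What follows is a minimal sketch of the equivalent check, not the upstream implementation, assuming the SparseTensor dialect helper getSparseTensorEncoding, which returns a null encoding attribute for dense tensor types:

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;

// Sketch only: the upstream pass asks its TypeConverter whether the result
// type is legal; for alloc_tensor that boils down to "no sparse encoding
// attribute on the result type", which this helper checks directly.
static bool isDenseAllocation(bufferization::AllocTensorOp op) {
  // getSparseTensorEncoding() returns a null attribute for dense tensor types.
  return !sparse_tensor::getSparseTensorEncoding(op.getType());
}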
@@ -92,9 +92,9 @@ struct SparseTensorConversionPass
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return, and tensor
-    // dim and cast operations as legal output of the rewriting provided that
-    // all sparse tensor types have been fully rewritten.
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting
+    // provided that all sparse tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -110,6 +110,10 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
       return converter.isLegal(op.getOperand().getType());
     });
+    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+        [&](bufferization::AllocTensorOp op) {
+          return converter.isLegal(op.getType());
+        });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
@@ -119,7 +123,6 @@ struct SparseTensorConversionPass
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect, scf::SCFDialect>();
-    target.addIllegalOp<bufferization::AllocTensorOp>();
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));
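The new behavior is exercised by a regression test in which a sparse and a dense allocation co-exist in the same function: the sparse alloc_tensor must be lowered to a newSparseTensor runtime call, while the dense one must survive as bufferization.alloc_tensor, as the CHECK lines in the hunk below verify.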
@@ -572,3 +572,14 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
   sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
   return
 }
+
+// CHECK-LABEL: func @sparse_and_dense_init(
+// CHECK: %[[S:.*]] = call @newSparseTensor
+// CHECK: %[[D:.*]] = bufferization.alloc_tensor
+// CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
+func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
+  -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
+  %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+  return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+}
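For completeness, a usage sketch of driving this lowering programmatically; it is a minimal example under the assumption that the standard entry point createSparseTensorConversionPass() is used, and the wrapper function lowerSparse is made up for illustration. The same pass is also exposed to mlir-opt, which is how the FileCheck test above is run.

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Runs the sparse tensor conversion over a module. With this patch, dense
// bufferization.alloc_tensor ops pass through unchanged; only allocations of
// sparse-annotated tensor types are rewritten into runtime library calls.
static LogicalResult lowerSparse(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createSparseTensorConversionPass());
  return pm.run(module);
}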