diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 5bbb4a5e1dda..66b92ec3b360 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -92,9 +92,9 @@ struct SparseTensorConversionPass ConversionTarget target(*ctx); // Everything in the sparse dialect must go! target.addIllegalDialect(); - // All dynamic rules below accept new function, call, return, and tensor - // dim and cast operations as legal output of the rewriting provided that - // all sparse tensor types have been fully rewritten. + // All dynamic rules below accept new function, call, return, and various + // tensor and bufferization operations as legal output of the rewriting + // provided that all sparse tensor types have been fully rewritten. target.addDynamicallyLegalOp([&](func::FuncOp op) { return converter.isSignatureLegal(op.getFunctionType()); }); @@ -110,6 +110,10 @@ struct SparseTensorConversionPass target.addDynamicallyLegalOp([&](tensor::CastOp op) { return converter.isLegal(op.getOperand().getType()); }); + target.addDynamicallyLegalOp( + [&](bufferization::AllocTensorOp op) { + return converter.isLegal(op.getType()); + }); // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. target.addLegalOp(); - target.addIllegalOp(); // Translate strategy flags to strategy options. SparseTensorConversionOptions options( sparseToSparseConversionStrategy(sparseToSparse)); diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir index 2a85d012b98e..d9b3ed1c2a3b 100644 --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -572,3 +572,14 @@ func.func @sparse_out2(%arg0: tensor, %arg1: !llvm.ptr sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr return } + +// CHECK-LABEL: func @sparse_and_dense_init( +// CHECK: %[[S:.*]] = call @newSparseTensor +// CHECK: %[[D:.*]] = bufferization.alloc_tensor +// CHECK: return %[[S]], %[[D]] : !llvm.ptr, tensor +func.func @sparse_and_dense_init(%arg0: index, %arg1: index) + -> (tensor, tensor) { + %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor + %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor + return %0, %1 : tensor, tensor +}