[mlir][sparse] Add rewrite rules for sparse-to-sparse reshape operators.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135077
This commit is contained in:
bixia1 2022-10-05 14:33:53 -07:00
parent 6b869be810
commit 330d48c4aa
2 changed files with 209 additions and 5 deletions
mlir
lib/Dialect/SparseTensor/Transforms
test/Dialect/SparseTensor

View File

@ -111,6 +111,20 @@ static bool isZeroYield(GenericOp op) {
return isZeroValue(yieldOp.getOperand(0)); return isZeroValue(yieldOp.getOperand(0));
} }
/// Populates given sizes array from type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVector<Value, 4> &sizes,
Location loc, ShapedType stp, Value tensor) {
for (const auto &d : enumerate(stp.getShape())) {
Value dim;
if (d.value() == ShapedType::kDynamicSize)
dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
else
dim = constantIndex(builder, loc, d.value());
sizes.push_back(dim);
}
}
// TODO: The dim level property of the COO type relies on input tensors, the // TODO: The dim level property of the COO type relies on input tensors, the
// shape relies on the output tensor // shape relies on the output tensor
// Helpers to setup a COO type. // Helpers to setup a COO type.
@ -119,8 +133,11 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
auto rank = src.getRank(); auto rank = src.getRank();
SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dims; SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dims;
// An unordered and non-unique compressed dim at beginning. // An unordered and non-unique compressed dim at beginning unless the tensor
dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo); // is a 1D tensor.
if (rank > 1)
dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
// TODO: it is actually ordered at the level for ordered input. // TODO: it is actually ordered at the level for ordered input.
// Followed by unordered non-unique n-2 singleton levels. // Followed by unordered non-unique n-2 singleton levels.
std::fill_n(std::back_inserter(dims), rank - 2, std::fill_n(std::back_inserter(dims), rank - 2,
@ -281,7 +298,72 @@ private:
} }
}; };
/// Sparse rewriting rule for reshape operator. /// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc) {
return failure();
}
// Generate code to represent the static dimension constants or compute
// the dynamic dimension values.
SmallVector<Value, 4> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
SmallVector<Value, 4> dstSizes;
SmallVector<Value, 4> dstDynSizes;
if (dstTp.hasStaticShape()) {
for (auto d : dstTp.getShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
} else {
ArrayRef<int64_t> dstShape = dstTp.getShape();
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
for (auto &d : llvm::enumerate(dstShape)) {
if (d.value() == ShapedType::kDynamicSize)
dstDynSizes.push_back(dstSizes[d.index()]);
}
}
// Implement the sparse2sparse reshape as follows:
// %tmp = bufferization.alloc_tensor : unordered COO
// foreach srcCoords %srcTensor
// insert translateIndicesArray(srcCoords), %tmp
// %t = sparse_tensor.cast %tmp
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
rewriter.create<ForeachOp>(
loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) {
SmallVector<Value, 4> srcIndices;
SmallVector<Value, 4> dstIndices;
for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
uint64_t dim = toStoredDim(encSrc, i);
srcIndices.push_back(args[dim]);
}
translateIndicesArray(builder, loc, op.getReassociationIndices(),
srcIndices, srcSizes, dstSizes, dstIndices);
builder.create<InsertOp>(loc, args.back(), cooBuffer, dstIndices);
builder.create<sparse_tensor::YieldOp>(loc);
});
rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
return success();
}
};
/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp> template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> { struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public: public:
@ -437,7 +519,6 @@ public:
//===---------------------------------------------------------------------===// //===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list. // Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===// //===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
bool enableRT) { bool enableRT) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
@ -446,5 +527,8 @@ void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
patterns.getContext()); patterns.getContext());
// TODO: If RT not enabled, rewrite concatenate ops, etc here. // TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) if (!enableRT)
patterns.add<ConcatenateRewriter>(patterns.getContext()); patterns.add<ConcatenateRewriter,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
patterns.getContext());
} }

View File

@ -1,5 +1,6 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
@ -37,6 +38,29 @@
// CHECK-CONV: call @delSparseTensorCOOF64 // CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
// //
// rewrite for codegen:
//
// CHECK-RWT-LABEL: func.func @sparse_expand(
// CHECK-RWT-SAME: %[[S:.*]]:
// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[C10]] : index
// CHECK-RWT: %[[DI1:.*]] = arith.remui %[[SI]], %[[C10]] : index
// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
// CHECK-RWT: }
// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
// CHECK-RWT: return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> { func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] : %0 = tensor.expand_shape %arg0 [[0, 1]] :
tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix> tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
@ -76,6 +100,37 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-CONV: call @delSparseTensorCOOF64 // CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
// //
// rewrite for codegen:
//
// CHECK-RWT-LABEL: func.func @sparse_collapse(
// CHECK-RWT-SAME: %[[S:.*]]:
// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
// CHECK-RWT: %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
// CHECK-RWT }
// CHECK-RWT: }
// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
// CHECK-RWT: return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> { func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : %0 = tensor.collapse_shape %arg0 [[0, 1]] :
tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector> tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
@ -120,6 +175,35 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-CONV: call @delSparseTensorCOOF64 // CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
// //
// rewrite for codegen:
//
// CHECK-RWT-LABEL: func.func @dynamic_sparse_expand(
// CHECK-RWT-SAME: %[[S:.*]]:
// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[SD:.*]] = tensor.dim %[[S]], %[[C0]]
// CHECK-RWT: %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK-RWT: %[[T1:.*]] = arith.muli %[[DD0]], %[[C10]] : index
// CHECK-RWT: %[[T2:.*]] = arith.divui %[[T1]], %[[DD0]] : index
// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[T2]] : index
// CHECK-RWT: %[[T3:.*]] = arith.remui %[[SI]], %[[T2]] : index
// CHECK-RWT: %[[T4:.*]] = arith.divui %[[T2]], %[[C10]] : index
// CHECK-RWT: %[[DI1:.*]] = arith.divui %[[T3]], %[[T4]] : index
// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
// CHECK-RWT: }
// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
// CHECK-RWT: return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> { func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] : %0 = tensor.expand_shape %arg0 [[0, 1]] :
tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix> tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
@ -163,6 +247,42 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-CONV: call @delSparseTensorCOOF64 // CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8> // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
// //
// rewrite for codegen:
//
// CHECK-RWT-LABEL: func.func @dynamic_sparse_collapse(
// CHECK-RWT-SAME: %[[S:.*]]:
// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[SD1:.*]] = tensor.dim %[[S]], %[[C1]]
// CHECK-RWT: %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
// CHECK-RWT: %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
// CHECK-RWT: %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
// CHECK-RWT: %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
// CHECK-RWT: %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
// CHECK-RWT }
// CHECK-RWT: }
// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
// CHECK-RWT: return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> { func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : %0 = tensor.collapse_shape %arg0 [[0, 1]] :
tensor<10x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector> tensor<10x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector>