[mlir][sparse] migrate sparse rewriting to sparse transformations pass

The rules in the linalg file were very specific to sparse tensors so will
find a better home under sparse tensor dialect than linalg dialect. Also
moved some rewriting from sparsification into this new "pre-rewriting" file.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D129910
This commit is contained in:
Aart Bik 2022-07-15 16:41:02 -07:00
parent 2570f226d1
commit 28ebb0b61d
9 changed files with 67 additions and 66 deletions

View File

@ -134,12 +134,19 @@ void populateSparseTensorConversionPatterns(
const SparseTensorConversionOptions &options =
SparseTensorConversionOptions());
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
std::unique_ptr<Pass> createSparseTensorConversionPass();
std::unique_ptr<Pass>
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
void populateSparseTensorRewriting(RewritePatternSet &patterns);
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//

View File

@ -22,7 +22,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
LinalgStrategyPasses.cpp
NamedOpConversions.cpp
Promotion.cpp
SparseTensorRewriting.cpp
Split.cpp
SplitReduction.cpp
Tiling.cpp

View File

@ -1717,12 +1717,8 @@ struct LinalgElementwiseOpFusionPass
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
// Add the sparse tensor rewriting patterns.
populateSparseTensorRewriting(patterns);
// General canonicalization patterns.
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);

View File

@ -52,11 +52,6 @@ void mlir::sparse_tensor::buildSparseCompiler(
OpPassManager &pm, const SparseCompilerOptions &options) {
// TODO(wrengr): ensure the original `pm` is for ModuleOp
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
// TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
// fit well with bufferization because it replaces unused "out" operands of
// LinalgOps with InitTensorOps. This would result in additional buffer
// allocations during bufferization.
// pm.addPass(createLinalgElementwiseOpFusionPass());
pm.addPass(
bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));

View File

@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Sparsification.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

View File

@ -49,13 +49,17 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
// Apply pre-rewriting.
RewritePatternSet prePatterns(ctx);
populateSparseTensorRewriting(prePatterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
// Translate strategy flags to strategy options.
SparsificationOptions options(
sparseParallelizationStrategy(parallelization),
sparseVectorizationStrategy(vectorization), vectorLength,
enableSIMDIndex32, enableVLAVectorization);
// Apply rewriting.
// Apply sparsification and vector cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

View File

@ -6,20 +6,16 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements linalg dialect rewriting specific to sparse tensors.
//
// Sparsity should be mostly transparent to the linalg dialect optimizations
// (i.e., the dense and sparse take the same path). However, in some cases,
// optimizations only make sense in the context of sparse tensors. This file
// implements such sparsity specific rewriting rules.
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
@ -98,6 +94,7 @@ static bool isSumOfMul(GenericOp op) {
//===---------------------------------------------------------------------===//
namespace {
/// Rewriting rule that converts two kernels:
///
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@ -114,6 +111,7 @@ namespace {
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
@ -194,13 +192,55 @@ private:
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}
};
/// Sparse rewriting rule for reshape operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(op.getResult().getType());
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
// Since a pure dense expansion is very cheap (change of view), for
// a sparse2dense or dense2sparse, we can simply unfuse a sparse
// conversion from the reshape operation itself.
// All other cases are handled elsewhere.
if (encDst && encSrc) {
return failure();
} else if (encSrc) {
RankedTensorType rtp =
op.getSrc().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
op->setOperand(0, convert);
return success();
} else if (encDst) {
RankedTensorType rtp =
op.getResult().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
op.getReassociation());
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
}
return failure();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<FuseSparseMultiplyOverAdd>(context);
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
// TODO(springerm): enable FuseSparseMultiplyOverAdd
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}

View File

@ -1802,46 +1802,6 @@ private:
SparsificationOptions options;
};
/// Sparse rewriting rule for reshape operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(op.getResult().getType());
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
// Since a pure dense expansion is very cheap (change of view), for
// a sparse2dense or dense2sparse, we can simply unfuse a sparse
// conversion from the reshape operation itself.
// All other cases are handled elsewhere.
if (encDst && encSrc) {
return failure();
} else if (encSrc) {
RankedTensorType rtp =
op.getSrc().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
op->setOperand(0, convert);
return success();
} else if (encDst) {
RankedTensorType rtp =
op.getResult().getType().template cast<RankedTensorType>();
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
op.getReassociation());
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
}
return failure();
}
};
} // namespace
/// Populates the given patterns list with rewriting rules required for
@ -1849,6 +1809,4 @@ public:
void mlir::populateSparsificationPatterns(
RewritePatternSet &patterns, const SparsificationOptions &options) {
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}

View File

@ -2115,6 +2115,7 @@ cc_library(
":SparseTensorDialect",
":SparseTensorPassIncGen",
":SparseTensorUtils",
":Support",
":TensorDialect",
":Transforms",
":VectorDialect",