[mlir][sparse] migrate sparse rewriting to sparse transformations pass
The rules in the linalg file were very specific to sparse tensors, so they find a better home under the sparse tensor dialect than under the linalg dialect. This change also moves some rewriting out of sparsification into the new "pre-rewriting" file.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D129910
parent 2570f226d1
commit 28ebb0b61d
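In API terms, the net effect of the change is that the pattern-population entry point moves out of the linalg namespace and into the sparse tensor dialect's transforms header. As a sketch, with both signatures taken from the hunks below:

// Before: provided by the Linalg transforms, registering FuseSparseMultiplyOverAdd.
void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns);

// After: declared in mlir/Dialect/SparseTensor/Transforms/Passes.h, registering
// the sparse reshape rewriting (re-enabling FuseSparseMultiplyOverAdd is a TODO).
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns);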
@@ -134,12 +134,19 @@ void populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options =
         SparseTensorConversionOptions());
 
-std::unique_ptr<Pass> createDenseBufferizationPass(
-    const bufferization::OneShotBufferizationOptions &options);
 std::unique_ptr<Pass> createSparseTensorConversionPass();
 std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createDenseBufferizationPass(
+    const bufferization::OneShotBufferizationOptions &options);
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
@@ -22,7 +22,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   LinalgStrategyPasses.cpp
   NamedOpConversions.cpp
   Promotion.cpp
-  SparseTensorRewriting.cpp
   Split.cpp
   SplitReduction.cpp
   Tiling.cpp
@@ -1717,12 +1717,8 @@ struct LinalgElementwiseOpFusionPass
 
   // Add elementwise op fusion patterns.
   populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
 
   populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
-
-  // Add the sparse tensor rewriting patterns.
-  populateSparseTensorRewriting(patterns);
-
   // General canonicalization patterns.
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
@@ -52,11 +52,6 @@ void mlir::sparse_tensor::buildSparseCompiler(
     OpPassManager &pm, const SparseCompilerOptions &options) {
   // TODO(wrengr): ensure the original `pm` is for ModuleOp
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
-  // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
-  // fit well with bufferization because it replaces unused "out" operands of
-  // LinalgOps with InitTensorOps. This would result in additional buffer
-  // allocations during bufferization.
-  // pm.addPass(createLinalgElementwiseOpFusionPass());
   pm.addPass(
       bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
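As a usage note, buildSparseCompiler is the entry point behind the sparse-compiler pipeline. A minimal sketch of driving it directly; the include paths and the PassManager setup are assumptions, only buildSparseCompiler and SparseCompilerOptions appear in this diff:

#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Build the pipeline shown above and run it on a module.
static mlir::LogicalResult runSparseCompiler(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::sparse_tensor::buildSparseCompiler(
      pm, mlir::sparse_tensor::SparseCompilerOptions());
  return pm.run(module);
}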
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Sparsification.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
+  SparseTensorRewriting.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
@@ -49,13 +49,17 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
 
   void runOnOperation() override {
     auto *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
+    // Apply pre-rewriting.
+    RewritePatternSet prePatterns(ctx);
+    populateSparseTensorRewriting(prePatterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Translate strategy flags to strategy options.
     SparsificationOptions options(
         sparseParallelizationStrategy(parallelization),
         sparseVectorizationStrategy(vectorization), vectorLength,
         enableSIMDIndex32, enableVLAVectorization);
-    // Apply rewriting.
+    // Apply sparsification and vector cleanup rewriting.
+    RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
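A note on the staging above: the pre-rewriting runs to a fixed point in its own greedy driver before the sparsification patterns are applied. When driving this programmatically rather than through pass flags, the options can be constructed directly; a hedged sketch (the enum values are illustrative assumptions, the constructor arity matches the call in the pass above):

// Construct sparsification options explicitly instead of translating pass flags.
SparsificationOptions options(
    SparseParallelizationStrategy::kNone, // sequential loops
    SparseVectorizationStrategy::kNone,   // no vectorization
    /*vectorLength=*/16,
    /*enableSIMDIndex32=*/false,
    /*enableVLAVectorization=*/false);
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);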
@@ -6,20 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements linalg dialect rewriting specific to sparse tensors.
-//
-// Sparsity should be mostly transparent to the linalg dialect optimizations
-// (i.e., the dense and sparse take the same path). However, in some cases,
-// optimizations only make sense in the context of sparse tensors. This file
-// implements such sparsity specific rewriting rules.
+// This file implements rewriting rules that are specific to sparse tensors.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Support/LLVM.h"
@@ -98,6 +94,7 @@ static bool isSumOfMul(GenericOp op) {
 //===---------------------------------------------------------------------===//
 
 namespace {
+
 /// Rewriting rule that converts two kernels:
 ///
 ///        T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -114,6 +111,7 @@ namespace {
 /// a fusion may actually reduce the asymptotic complexity of the kernel,
 /// since intermediate results may be nullified.
 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+public:
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp op,
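The rule documentation is split across the two hunks above; reconstructing the second kernel from the rule's name and the visible first kernel (the S and X names are an assumption, written in the same notation the comment uses), FuseSparseMultiplyOverAdd fuses

    T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
    X(i,j) = S(i,j) * T(i,j)

into a single kernel by distributing the multiplication over the summation:

    X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )

When S is sparse, the fused kernel evaluates the summation only where S(i,j) is nonzero, which is how the fusion can reduce the asymptotic complexity: intermediate values of T that would be nullified by S are never computed.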
@@ -194,13 +192,55 @@ private:
     mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
   }
 };
+
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+public:
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto encDst = getSparseTensorEncoding(op.getResult().getType());
+    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
+    // Since a pure dense expansion is very cheap (change of view), for
+    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+    // conversion from the reshape operation itself.
+    // All other cases are handled elsewhere.
+    if (encDst && encSrc) {
+      return failure();
+    } else if (encSrc) {
+      RankedTensorType rtp =
+          op.getSrc().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
+      op->setOperand(0, convert);
+      return success();
+    } else if (encDst) {
+      RankedTensorType rtp =
+          op.getResult().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                                op.getReassociation());
+      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+      rewriter.replaceOp(op, convert);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
-void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  auto *context = patterns.getContext();
-  patterns.add<FuseSparseMultiplyOverAdd>(context);
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+  // TODO(springerm): enable FuseSparseMultiplyOverAdd
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
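To make the dispatch concrete: a dense-to-sparse tensor.expand_shape or tensor.collapse_shape is rewritten to reshape on the dense side and then sparse_tensor.convert into the annotated type, while sparse-to-dense converts the source to a dense tensor first; the sparse-to-sparse case is left to other rules. A minimal sketch of applying just this migrated pre-rewriting from C++ (the helper name is hypothetical; the entry point and driver call are the same ones SparsificationPass uses above):

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Apply the sparse tensor pre-rewriting (currently the two ReshapeRewriter
// instantiations registered above) until no more patterns match.
static mlir::LogicalResult applySparsePreRewriting(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateSparseTensorRewriting(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}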
@@ -1802,46 +1802,6 @@ private:
   SparsificationOptions options;
 };
 
-/// Sparse rewriting rule for reshape operator.
-template <typename ReshapeOp>
-struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
-public:
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    auto encDst = getSparseTensorEncoding(op.getResult().getType());
-    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
-    // Since a pure dense expansion is very cheap (change of view), for
-    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the reshape operation itself.
-    // All other cases are handled elsewhere.
-    if (encDst && encSrc) {
-      return failure();
-    } else if (encSrc) {
-      RankedTensorType rtp =
-          op.getSrc().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
-      return success();
-    } else if (encDst) {
-      RankedTensorType rtp =
-          op.getResult().getType().template cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                                op.getReassociation());
-      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
-      rewriter.replaceOp(op, convert);
-      return success();
-    }
-    return failure();
-  }
-};
-
 } // namespace
 
 /// Populates the given patterns list with rewriting rules required for
@@ -1849,6 +1809,4 @@ public:
 void mlir::populateSparsificationPatterns(
     RewritePatternSet &patterns, const SparsificationOptions &options) {
   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
-  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
@@ -2115,6 +2115,7 @@ cc_library(
         ":SparseTensorDialect",
        ":SparseTensorPassIncGen",
        ":SparseTensorUtils",
+        ":Support",
        ":TensorDialect",
        ":Transforms",
        ":VectorDialect",