forked from OSchip/llvm-project
[mlir][sparse] migrate sparse rewriting to sparse transformations pass
The rules in the linalg file were very specific to sparse tensors so will find a better home under sparse tensor dialect than linalg dialect. Also moved some rewriting from sparsification into this new "pre-rewriting" file. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D129910
This commit is contained in:
parent
2570f226d1
commit
28ebb0b61d
|
@ -134,12 +134,19 @@ void populateSparseTensorConversionPatterns(
|
||||||
const SparseTensorConversionOptions &options =
|
const SparseTensorConversionOptions &options =
|
||||||
SparseTensorConversionOptions());
|
SparseTensorConversionOptions());
|
||||||
|
|
||||||
std::unique_ptr<Pass> createDenseBufferizationPass(
|
|
||||||
const bufferization::OneShotBufferizationOptions &options);
|
|
||||||
std::unique_ptr<Pass> createSparseTensorConversionPass();
|
std::unique_ptr<Pass> createSparseTensorConversionPass();
|
||||||
std::unique_ptr<Pass>
|
std::unique_ptr<Pass>
|
||||||
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
|
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Other rewriting rules and passes.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void populateSparseTensorRewriting(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createDenseBufferizationPass(
|
||||||
|
const bufferization::OneShotBufferizationOptions &options);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Registration.
|
// Registration.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -22,7 +22,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
||||||
LinalgStrategyPasses.cpp
|
LinalgStrategyPasses.cpp
|
||||||
NamedOpConversions.cpp
|
NamedOpConversions.cpp
|
||||||
Promotion.cpp
|
Promotion.cpp
|
||||||
SparseTensorRewriting.cpp
|
|
||||||
Split.cpp
|
Split.cpp
|
||||||
SplitReduction.cpp
|
SplitReduction.cpp
|
||||||
Tiling.cpp
|
Tiling.cpp
|
||||||
|
|
|
@ -1717,12 +1717,8 @@ struct LinalgElementwiseOpFusionPass
|
||||||
|
|
||||||
// Add elementwise op fusion patterns.
|
// Add elementwise op fusion patterns.
|
||||||
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
|
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
|
||||||
|
|
||||||
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
|
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
|
||||||
|
|
||||||
// Add the sparse tensor rewriting patterns.
|
|
||||||
populateSparseTensorRewriting(patterns);
|
|
||||||
|
|
||||||
// General canonicalization patterns.
|
// General canonicalization patterns.
|
||||||
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||||
GenericOp::getCanonicalizationPatterns(patterns, context);
|
GenericOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
|
|
@ -52,11 +52,6 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
||||||
OpPassManager &pm, const SparseCompilerOptions &options) {
|
OpPassManager &pm, const SparseCompilerOptions &options) {
|
||||||
// TODO(wrengr): ensure the original `pm` is for ModuleOp
|
// TODO(wrengr): ensure the original `pm` is for ModuleOp
|
||||||
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
|
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
|
||||||
// TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
|
|
||||||
// fit well with bufferization because it replaces unused "out" operands of
|
|
||||||
// LinalgOps with InitTensorOps. This would result in additional buffer
|
|
||||||
// allocations during bufferization.
|
|
||||||
// pm.addPass(createLinalgElementwiseOpFusionPass());
|
|
||||||
pm.addPass(
|
pm.addPass(
|
||||||
bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
|
bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
|
||||||
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
|
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
|
||||||
|
|
|
@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
|
||||||
Sparsification.cpp
|
Sparsification.cpp
|
||||||
SparseTensorConversion.cpp
|
SparseTensorConversion.cpp
|
||||||
SparseTensorPasses.cpp
|
SparseTensorPasses.cpp
|
||||||
|
SparseTensorRewriting.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
|
||||||
|
|
|
@ -49,13 +49,17 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto *ctx = &getContext();
|
auto *ctx = &getContext();
|
||||||
RewritePatternSet patterns(ctx);
|
// Apply pre-rewriting.
|
||||||
|
RewritePatternSet prePatterns(ctx);
|
||||||
|
populateSparseTensorRewriting(prePatterns);
|
||||||
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
|
||||||
// Translate strategy flags to strategy options.
|
// Translate strategy flags to strategy options.
|
||||||
SparsificationOptions options(
|
SparsificationOptions options(
|
||||||
sparseParallelizationStrategy(parallelization),
|
sparseParallelizationStrategy(parallelization),
|
||||||
sparseVectorizationStrategy(vectorization), vectorLength,
|
sparseVectorizationStrategy(vectorization), vectorLength,
|
||||||
enableSIMDIndex32, enableVLAVectorization);
|
enableSIMDIndex32, enableVLAVectorization);
|
||||||
// Apply rewriting.
|
// Apply sparsification and vector cleanup rewriting.
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
populateSparsificationPatterns(patterns, options);
|
populateSparsificationPatterns(patterns, options);
|
||||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
||||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||||
|
|
|
@ -6,20 +6,16 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file implements linalg dialect rewriting specific to sparse tensors.
|
// This file implements rewriting rules that are specific to sparse tensors.
|
||||||
//
|
|
||||||
// Sparsity should be mostly transparent to the linalg dialect optimizations
|
|
||||||
// (i.e., the dense and sparse take the same path). However, in some cases,
|
|
||||||
// optimizations only make sense in the context of sparse tensors. This file
|
|
||||||
// implements such sparsity specific rewriting rules.
|
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
||||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||||
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
@ -98,6 +94,7 @@ static bool isSumOfMul(GenericOp op) {
|
||||||
//===---------------------------------------------------------------------===//
|
//===---------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
/// Rewriting rule that converts two kernels:
|
/// Rewriting rule that converts two kernels:
|
||||||
///
|
///
|
||||||
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
|
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
|
||||||
|
@ -114,6 +111,7 @@ namespace {
|
||||||
/// a fusion may actually reduce the asymptotic complexity of the kernel,
|
/// a fusion may actually reduce the asymptotic complexity of the kernel,
|
||||||
/// since intermediate results may be nullified.
|
/// since intermediate results may be nullified.
|
||||||
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
|
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
|
||||||
|
public:
|
||||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(GenericOp op,
|
LogicalResult matchAndRewrite(GenericOp op,
|
||||||
|
@ -194,13 +192,55 @@ private:
|
||||||
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
|
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Sparse rewriting rule for reshape operator.
|
||||||
|
template <typename ReshapeOp>
|
||||||
|
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ReshapeOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
auto encDst = getSparseTensorEncoding(op.getResult().getType());
|
||||||
|
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
|
||||||
|
// Since a pure dense expansion is very cheap (change of view), for
|
||||||
|
// a sparse2dense or dense2sparse, we can simply unfuse a sparse
|
||||||
|
// conversion from the reshape operation itself.
|
||||||
|
// All other cases are handled elsewhere.
|
||||||
|
if (encDst && encSrc) {
|
||||||
|
return failure();
|
||||||
|
} else if (encSrc) {
|
||||||
|
RankedTensorType rtp =
|
||||||
|
op.getSrc().getType().template cast<RankedTensorType>();
|
||||||
|
auto denseTp =
|
||||||
|
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||||
|
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
|
||||||
|
op->setOperand(0, convert);
|
||||||
|
return success();
|
||||||
|
} else if (encDst) {
|
||||||
|
RankedTensorType rtp =
|
||||||
|
op.getResult().getType().template cast<RankedTensorType>();
|
||||||
|
auto denseTp =
|
||||||
|
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||||
|
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
|
||||||
|
op.getReassociation());
|
||||||
|
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
|
||||||
|
rewriter.replaceOp(op, convert);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
//===---------------------------------------------------------------------===//
|
//===---------------------------------------------------------------------===//
|
||||||
// Methods that add patterns described in this file to a pattern list.
|
// Methods that add patterns described in this file to a pattern list.
|
||||||
//===---------------------------------------------------------------------===//
|
//===---------------------------------------------------------------------===//
|
||||||
|
|
||||||
void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
|
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
|
||||||
auto *context = patterns.getContext();
|
// TODO(springerm): enable FuseSparseMultiplyOverAdd
|
||||||
patterns.add<FuseSparseMultiplyOverAdd>(context);
|
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
|
||||||
|
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
|
||||||
}
|
}
|
|
@ -1802,46 +1802,6 @@ private:
|
||||||
SparsificationOptions options;
|
SparsificationOptions options;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Sparse rewriting rule for reshape operator.
|
|
||||||
template <typename ReshapeOp>
|
|
||||||
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ReshapeOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
Location loc = op->getLoc();
|
|
||||||
auto encDst = getSparseTensorEncoding(op.getResult().getType());
|
|
||||||
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
|
|
||||||
// Since a pure dense expansion is very cheap (change of view), for
|
|
||||||
// a sparse2dense or dense2sparse, we can simply unfuse a sparse
|
|
||||||
// conversion from the reshape operation itself.
|
|
||||||
// All other cases are handled elsewhere.
|
|
||||||
if (encDst && encSrc) {
|
|
||||||
return failure();
|
|
||||||
} else if (encSrc) {
|
|
||||||
RankedTensorType rtp =
|
|
||||||
op.getSrc().getType().template cast<RankedTensorType>();
|
|
||||||
auto denseTp =
|
|
||||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
|
||||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
|
|
||||||
op->setOperand(0, convert);
|
|
||||||
return success();
|
|
||||||
} else if (encDst) {
|
|
||||||
RankedTensorType rtp =
|
|
||||||
op.getResult().getType().template cast<RankedTensorType>();
|
|
||||||
auto denseTp =
|
|
||||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
|
||||||
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
|
|
||||||
op.getReassociation());
|
|
||||||
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
|
|
||||||
rewriter.replaceOp(op, convert);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/// Populates the given patterns list with rewriting rules required for
|
/// Populates the given patterns list with rewriting rules required for
|
||||||
|
@ -1849,6 +1809,4 @@ public:
|
||||||
void mlir::populateSparsificationPatterns(
|
void mlir::populateSparsificationPatterns(
|
||||||
RewritePatternSet &patterns, const SparsificationOptions &options) {
|
RewritePatternSet &patterns, const SparsificationOptions &options) {
|
||||||
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
|
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
|
||||||
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
|
|
||||||
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2115,6 +2115,7 @@ cc_library(
|
||||||
":SparseTensorDialect",
|
":SparseTensorDialect",
|
||||||
":SparseTensorPassIncGen",
|
":SparseTensorPassIncGen",
|
||||||
":SparseTensorUtils",
|
":SparseTensorUtils",
|
||||||
|
":Support",
|
||||||
":TensorDialect",
|
":TensorDialect",
|
||||||
":Transforms",
|
":Transforms",
|
||||||
":VectorDialect",
|
":VectorDialect",
|
||||||
|
|
Loading…
Reference in New Issue