[mlir][sparse] move sparse tensor rewriting into its own pass

Makes individual testing and debugging easier.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D135319
This commit is contained in:
Aart Bik 2022-10-05 13:38:51 -07:00
parent 617ca92bf1
commit 779dcd2ecc
8 changed files with 59 additions and 8 deletions

View File

@ -163,6 +163,10 @@ std::unique_ptr<Pass> createSparseTensorCodegenPass();
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT); void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
std::unique_ptr<Pass> createSparseTensorRewritePass();
std::unique_ptr<Pass>
createSparseTensorRewritePass(const SparsificationOptions &options);
std::unique_ptr<Pass> createDenseBufferizationPass( std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options); const bufferization::OneShotBufferizationOptions &options);

View File

@ -11,6 +11,27 @@
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
let description = [{
A pass that applies rewriting rules to sparse tensor operations prior
to running the actual sparsification pass.
}];
let constructor = "mlir::createSparseTensorRewritePass()";
let dependentDialects = [
"arith::ArithDialect",
"bufferization::BufferizationDialect",
"linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
let options = [
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
"true", "Enable runtime library for manipulating sparse tensors">
];
}
def SparsificationPass : Pass<"sparsification", "ModuleOp"> { def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
let summary = "Automatically generate sparse tensor code from sparse tensor types"; let summary = "Automatically generate sparse tensor code from sparse tensor types";
let description = [{ let description = [{
@ -57,6 +78,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
"arith::ArithDialect", "arith::ArithDialect",
"bufferization::BufferizationDialect", "bufferization::BufferizationDialect",
"LLVM::LLVMDialect", "LLVM::LLVMDialect",
"linalg::LinalgDialect",
"memref::MemRefDialect", "memref::MemRefDialect",
"scf::SCFDialect", "scf::SCFDialect",
"sparse_tensor::SparseTensorDialect", "sparse_tensor::SparseTensorDialect",
@ -193,4 +215,5 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
"sparse_tensor::SparseTensorDialect", "sparse_tensor::SparseTensorDialect",
]; ];
} }
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

View File

@ -58,6 +58,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
/*analysisOnly=*/options.testBufferizationAnalysisOnly))); /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
if (options.testBufferizationAnalysisOnly) if (options.testBufferizationAnalysisOnly)
return; return;
pm.addPass(createSparseTensorRewritePass(options.sparsificationOptions()));
pm.addPass(createSparsificationPass(options.sparsificationOptions())); pm.addPass(createSparsificationPass(options.sparsificationOptions()));
if (options.enableRuntimeLibrary) if (options.enableRuntimeLibrary)
pm.addPass(createSparseTensorConversionPass( pm.addPass(createSparseTensorConversionPass(

View File

@ -21,6 +21,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir { namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS #define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN #define GEN_PASS_DEF_SPARSETENSORCODEGEN
@ -37,6 +38,23 @@ namespace {
// Passes implementation. // Passes implementation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
struct SparseTensorRewritePass
: public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
SparseTensorRewritePass() = default;
SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
SparseTensorRewritePass(const SparsificationOptions &options) {
enableRuntimeLibrary = options.enableRuntimeLibrary;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct SparsificationPass struct SparsificationPass
: public impl::SparsificationPassBase<SparsificationPass> { : public impl::SparsificationPassBase<SparsificationPass> {
@ -53,14 +71,10 @@ struct SparsificationPass
void runOnOperation() override { void runOnOperation() override {
auto *ctx = &getContext(); auto *ctx = &getContext();
RewritePatternSet prePatterns(ctx);
// Translate strategy flags to strategy options. // Translate strategy flags to strategy options.
SparsificationOptions options(parallelization, vectorization, vectorLength, SparsificationOptions options(parallelization, vectorization, vectorLength,
enableSIMDIndex32, enableVLAVectorization, enableSIMDIndex32, enableVLAVectorization,
enableRuntimeLibrary); enableRuntimeLibrary);
// Apply pre-rewriting.
populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
// Apply sparsification and vector cleanup rewriting. // Apply sparsification and vector cleanup rewriting.
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options); populateSparsificationPatterns(patterns, options);
@ -236,6 +250,15 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
// Pass creation methods. // Pass creation methods.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
return std::make_unique<SparseTensorRewritePass>();
}
std::unique_ptr<Pass>
mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
return std::make_unique<SparseTensorRewritePass>(options);
}
std::unique_ptr<Pass> mlir::createSparsificationPass() { std::unique_ptr<Pass> mlir::createSparsificationPass() {
return std::make_unique<SparsificationPass>(); return std::make_unique<SparsificationPass>();
} }

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -sparsification | FileCheck %s // RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ #SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"] dimLevelType = ["compressed"]

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --sparsification=enable-runtime-library=false | FileCheck %s // RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s // RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --tensor-copy-insertion --sparsification --cse | FileCheck %s // RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>