forked from OSchip/llvm-project
[mlir][sparse] move sparse tensor rewriting into its own pass
Makes individual testing and debugging easier. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D135319
This commit is contained in:
parent
617ca92bf1
commit
779dcd2ecc
|
@ -163,6 +163,10 @@ std::unique_ptr<Pass> createSparseTensorCodegenPass();
|
||||||
|
|
||||||
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
|
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createSparseTensorRewritePass();
|
||||||
|
std::unique_ptr<Pass>
|
||||||
|
createSparseTensorRewritePass(const SparsificationOptions &options);
|
||||||
|
|
||||||
std::unique_ptr<Pass> createDenseBufferizationPass(
|
std::unique_ptr<Pass> createDenseBufferizationPass(
|
||||||
const bufferization::OneShotBufferizationOptions &options);
|
const bufferization::OneShotBufferizationOptions &options);
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,27 @@
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
|
||||||
|
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
|
||||||
|
let description = [{
|
||||||
|
A pass that applies rewriting rules to sparse tensor operations prior
|
||||||
|
to running the actual sparsification pass.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::createSparseTensorRewritePass()";
|
||||||
|
let dependentDialects = [
|
||||||
|
"arith::ArithDialect",
|
||||||
|
"bufferization::BufferizationDialect",
|
||||||
|
"linalg::LinalgDialect",
|
||||||
|
"memref::MemRefDialect",
|
||||||
|
"scf::SCFDialect",
|
||||||
|
"sparse_tensor::SparseTensorDialect",
|
||||||
|
];
|
||||||
|
let options = [
|
||||||
|
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
|
||||||
|
"true", "Enable runtime library for manipulating sparse tensors">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
|
def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
|
||||||
let summary = "Automatically generate sparse tensor code from sparse tensor types";
|
let summary = "Automatically generate sparse tensor code from sparse tensor types";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -57,6 +78,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
|
||||||
"arith::ArithDialect",
|
"arith::ArithDialect",
|
||||||
"bufferization::BufferizationDialect",
|
"bufferization::BufferizationDialect",
|
||||||
"LLVM::LLVMDialect",
|
"LLVM::LLVMDialect",
|
||||||
|
"linalg::LinalgDialect",
|
||||||
"memref::MemRefDialect",
|
"memref::MemRefDialect",
|
||||||
"scf::SCFDialect",
|
"scf::SCFDialect",
|
||||||
"sparse_tensor::SparseTensorDialect",
|
"sparse_tensor::SparseTensorDialect",
|
||||||
|
@ -193,4 +215,5 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
|
||||||
"sparse_tensor::SparseTensorDialect",
|
"sparse_tensor::SparseTensorDialect",
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
|
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
|
||||||
|
|
|
@ -58,6 +58,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
||||||
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
|
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
|
||||||
if (options.testBufferizationAnalysisOnly)
|
if (options.testBufferizationAnalysisOnly)
|
||||||
return;
|
return;
|
||||||
|
pm.addPass(createSparseTensorRewritePass(options.sparsificationOptions()));
|
||||||
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
|
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
|
||||||
if (options.enableRuntimeLibrary)
|
if (options.enableRuntimeLibrary)
|
||||||
pm.addPass(createSparseTensorConversionPass(
|
pm.addPass(createSparseTensorConversionPass(
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
#define GEN_PASS_DEF_SPARSETENSORREWRITE
|
||||||
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
||||||
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
||||||
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
||||||
|
@ -37,6 +38,23 @@ namespace {
|
||||||
// Passes implementation.
|
// Passes implementation.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct SparseTensorRewritePass
|
||||||
|
: public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
|
||||||
|
|
||||||
|
SparseTensorRewritePass() = default;
|
||||||
|
SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
|
||||||
|
SparseTensorRewritePass(const SparsificationOptions &options) {
|
||||||
|
enableRuntimeLibrary = options.enableRuntimeLibrary;
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto *ctx = &getContext();
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
|
||||||
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct SparsificationPass
|
struct SparsificationPass
|
||||||
: public impl::SparsificationPassBase<SparsificationPass> {
|
: public impl::SparsificationPassBase<SparsificationPass> {
|
||||||
|
|
||||||
|
@ -53,14 +71,10 @@ struct SparsificationPass
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto *ctx = &getContext();
|
auto *ctx = &getContext();
|
||||||
RewritePatternSet prePatterns(ctx);
|
|
||||||
// Translate strategy flags to strategy options.
|
// Translate strategy flags to strategy options.
|
||||||
SparsificationOptions options(parallelization, vectorization, vectorLength,
|
SparsificationOptions options(parallelization, vectorization, vectorLength,
|
||||||
enableSIMDIndex32, enableVLAVectorization,
|
enableSIMDIndex32, enableVLAVectorization,
|
||||||
enableRuntimeLibrary);
|
enableRuntimeLibrary);
|
||||||
// Apply pre-rewriting.
|
|
||||||
populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary);
|
|
||||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
|
|
||||||
// Apply sparsification and vector cleanup rewriting.
|
// Apply sparsification and vector cleanup rewriting.
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
populateSparsificationPatterns(patterns, options);
|
populateSparsificationPatterns(patterns, options);
|
||||||
|
@ -236,6 +250,15 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
|
||||||
// Pass creation methods.
|
// Pass creation methods.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
|
||||||
|
return std::make_unique<SparseTensorRewritePass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<Pass>
|
||||||
|
mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
|
||||||
|
return std::make_unique<SparseTensorRewritePass>(options);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
||||||
return std::make_unique<SparsificationPass>();
|
return std::make_unique<SparsificationPass>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s -sparsification | FileCheck %s
|
// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
|
||||||
|
|
||||||
#SparseVector = #sparse_tensor.encoding<{
|
#SparseVector = #sparse_tensor.encoding<{
|
||||||
dimLevelType = ["compressed"]
|
dimLevelType = ["compressed"]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s --sparsification=enable-runtime-library=false | FileCheck %s
|
// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s
|
||||||
|
|
||||||
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
|
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
||||||
|
|
||||||
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s --tensor-copy-insertion --sparsification --cse | FileCheck %s
|
// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
|
||||||
|
|
||||||
#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue