Migrate MLIR test passes to the new registration API

Make sure they all define getArgument()/getDescription().

Depends On D104421

Differential Revision: https://reviews.llvm.org/D104426
This commit is contained in:
Mehdi Amini 2021-06-16 23:42:13 +00:00
parent c8a3f561eb
commit b5e22e6d42
64 changed files with 578 additions and 323 deletions

View File

@ -47,6 +47,11 @@ class SerializeToCubinPass
public:
SerializeToCubinPass();
StringRef getArgument() const override { return "gpu-to-cubin"; }
StringRef getDescription() const override {
return "Lower GPU kernel function to CUBIN binary annotations";
}
private:
void getDependentDialects(DialectRegistry &registry) const override;
@ -126,7 +131,6 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
void mlir::registerGpuSerializeToCubinPass() {
PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
"gpu-to-cubin", "Lower GPU kernel function to CUBIN binary annotations",
[] {
// Initialize LLVM NVPTX backend.
LLVMInitializeNVPTXTarget();

View File

@ -50,6 +50,11 @@ class SerializeToHsacoPass
public:
SerializeToHsacoPass();
StringRef getArgument() const override { return "gpu-to-hsaco"; }
StringRef getDescription() const override {
return "Lower GPU kernel function to HSACO binary annotations";
}
private:
void getDependentDialects(DialectRegistry &registry) const override;
@ -268,7 +273,6 @@ SerializeToHsacoPass::serializeISA(const std::string &isa) {
// Register pass to serialize GPU kernel functions to a HSACO binary annotation.
void mlir::registerGpuSerializeToHsacoPass() {
PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
"gpu-to-hsaco", "Lower GPU kernel function to HSACO binary annotations",
[] {
// Initialize LLVM AMDGPU backend.
LLVMInitializeAMDGPUAsmParser();

View File

@ -46,6 +46,10 @@ static void printAliasOperand(Value value) {
namespace {
struct TestAliasAnalysisPass
: public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
StringRef getArgument() const final { return "test-alias-analysis"; }
StringRef getDescription() const final {
return "Test alias analysis results.";
}
void runOnOperation() override {
llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
@ -84,6 +88,10 @@ struct TestAliasAnalysisPass
namespace {
struct TestAliasAnalysisModRefPass
: public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
StringRef getArgument() const final { return "test-alias-analysis-modref"; }
StringRef getDescription() const final {
return "Test alias analysis ModRef results.";
}
void runOnOperation() override {
llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
@ -126,10 +134,8 @@ struct TestAliasAnalysisModRefPass
namespace mlir {
namespace test {
void registerTestAliasAnalysisPass() {
PassRegistration<TestAliasAnalysisPass> aliasPass(
"test-alias-analysis", "Test alias analysis results.");
PassRegistration<TestAliasAnalysisModRefPass> modRefPass(
"test-alias-analysis-modref", "Test alias analysis ModRef results.");
PassRegistration<TestAliasAnalysisPass>();
PassRegistration<TestAliasAnalysisModRefPass>();
}
} // namespace test
} // namespace mlir

View File

@ -19,6 +19,10 @@ using namespace mlir;
namespace {
struct TestCallGraphPass
: public PassWrapper<TestCallGraphPass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-print-callgraph"; }
StringRef getDescription() const final {
return "Print the contents of a constructed callgraph.";
}
void runOnOperation() override {
llvm::errs() << "Testing : " << getOperation()->getAttr("test.name")
<< "\n";
@ -29,9 +33,6 @@ struct TestCallGraphPass
namespace mlir {
namespace test {
void registerTestCallGraphPass() {
PassRegistration<TestCallGraphPass> pass(
"test-print-callgraph", "Print the contents of a constructed callgraph.");
}
void registerTestCallGraphPass() { PassRegistration<TestCallGraphPass>(); }
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,10 @@ using namespace mlir;
namespace {
struct TestLivenessPass : public PassWrapper<TestLivenessPass, FunctionPass> {
StringRef getArgument() const final { return "test-print-liveness"; }
StringRef getDescription() const final {
return "Print the contents of a constructed liveness information.";
}
void runOnFunction() override {
llvm::errs() << "Testing : " << getFunction().getName() << "\n";
getAnalysis<Liveness>().print(llvm::errs());
@ -30,10 +34,6 @@ struct TestLivenessPass : public PassWrapper<TestLivenessPass, FunctionPass> {
namespace mlir {
namespace test {
void registerTestLivenessPass() {
PassRegistration<TestLivenessPass>(
"test-print-liveness",
"Print the contents of a constructed liveness information.");
}
void registerTestLivenessPass() { PassRegistration<TestLivenessPass>(); }
} // namespace test
} // namespace mlir

View File

@ -29,6 +29,10 @@ namespace {
/// Checks for out of bound memref access subscripts..
struct TestMemRefBoundCheck
: public PassWrapper<TestMemRefBoundCheck, FunctionPass> {
StringRef getArgument() const final { return "test-memref-bound-check"; }
StringRef getDescription() const final {
return "Check memref access bounds in a Function";
}
void runOnFunction() override;
};
@ -46,9 +50,6 @@ void TestMemRefBoundCheck::runOnFunction() {
namespace mlir {
namespace test {
void registerMemRefBoundCheck() {
PassRegistration<TestMemRefBoundCheck>(
"test-memref-bound-check", "Check memref access bounds in a Function");
}
void registerMemRefBoundCheck() { PassRegistration<TestMemRefBoundCheck>(); }
} // namespace test
} // namespace mlir

View File

@ -28,6 +28,10 @@ namespace {
/// Checks dependences between all pairs of memref accesses in a Function.
struct TestMemRefDependenceCheck
: public PassWrapper<TestMemRefDependenceCheck, FunctionPass> {
StringRef getArgument() const final { return "test-memref-dependence-check"; }
StringRef getDescription() const final {
return "Checks dependences between all pairs of memref accesses.";
}
SmallVector<Operation *, 4> loadsAndStores;
void runOnFunction() override;
};
@ -112,9 +116,7 @@ void TestMemRefDependenceCheck::runOnFunction() {
namespace mlir {
namespace test {
void registerTestMemRefDependenceCheck() {
PassRegistration<TestMemRefDependenceCheck> pass(
"test-memref-dependence-check",
"Checks dependences between all pairs of memref accesses.");
PassRegistration<TestMemRefDependenceCheck>();
}
} // namespace test
} // namespace mlir

View File

@ -15,6 +15,12 @@ using namespace mlir;
namespace {
struct TestMemRefStrideCalculation
: public PassWrapper<TestMemRefStrideCalculation, FunctionPass> {
StringRef getArgument() const final {
return "test-memref-stride-calculation";
}
StringRef getDescription() const final {
return "Test operation constant folding";
}
void runOnFunction() override;
};
} // end anonymous namespace
@ -51,8 +57,7 @@ void TestMemRefStrideCalculation::runOnFunction() {
namespace mlir {
namespace test {
void registerTestMemRefStrideCalculation() {
PassRegistration<TestMemRefStrideCalculation> pass(
"test-memref-stride-calculation", "Test operation constant folding");
PassRegistration<TestMemRefStrideCalculation>();
}
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,14 @@ namespace {
struct TestNumberOfBlockExecutionsPass
: public PassWrapper<TestNumberOfBlockExecutionsPass, FunctionPass> {
StringRef getArgument() const final {
return "test-print-number-of-block-executions";
}
StringRef getDescription() const final {
return "Print the contents of a constructed number of executions analysis "
"for "
"all blocks.";
}
void runOnFunction() override {
llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
getAnalysis<NumberOfExecutions>().printBlockExecutions(
@ -29,6 +37,14 @@ struct TestNumberOfBlockExecutionsPass
struct TestNumberOfOperationExecutionsPass
: public PassWrapper<TestNumberOfOperationExecutionsPass, FunctionPass> {
StringRef getArgument() const final {
return "test-print-number-of-operation-executions";
}
StringRef getDescription() const final {
return "Print the contents of a constructed number of executions analysis "
"for "
"all operations.";
}
void runOnFunction() override {
llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
getAnalysis<NumberOfExecutions>().printOperationExecutions(
@ -41,17 +57,11 @@ struct TestNumberOfOperationExecutionsPass
namespace mlir {
namespace test {
void registerTestNumberOfBlockExecutionsPass() {
PassRegistration<TestNumberOfBlockExecutionsPass>(
"test-print-number-of-block-executions",
"Print the contents of a constructed number of executions analysis for "
"all blocks.");
PassRegistration<TestNumberOfBlockExecutionsPass>();
}
void registerTestNumberOfOperationExecutionsPass() {
PassRegistration<TestNumberOfOperationExecutionsPass>(
"test-print-number-of-operation-executions",
"Print the contents of a constructed number of executions analysis for "
"all operations.");
PassRegistration<TestNumberOfOperationExecutionsPass>();
}
} // namespace test
} // namespace mlir

View File

@ -38,6 +38,11 @@ public:
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
}
StringRef getArgument() const final { return "test-convert-call-op"; }
StringRef getDescription() const final {
return "Tests conversion of `std.call` to `llvm.call` in "
"presence of custom types";
}
void runOnOperation() override {
ModuleOp m = getOperation();
@ -68,11 +73,6 @@ public:
namespace mlir {
namespace test {
void registerConvertCallOpPass() {
PassRegistration<TestConvertCallOp>(
"test-convert-call-op",
"Tests conversion of `std.call` to `llvm.call` in "
"presence of custom types");
}
void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
} // namespace test
} // namespace mlir

View File

@ -29,6 +29,10 @@ namespace {
struct TestAffineDataCopy
: public PassWrapper<TestAffineDataCopy, FunctionPass> {
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Tests affine data copy utility functions.";
}
TestAffineDataCopy() = default;
TestAffineDataCopy(const TestAffineDataCopy &pass){};
@ -128,7 +132,6 @@ void TestAffineDataCopy::runOnFunction() {
namespace mlir {
void registerTestAffineDataCopyPass() {
PassRegistration<TestAffineDataCopy>(
PASS_NAME, "Tests affine data copy utility functions.");
PassRegistration<TestAffineDataCopy>();
}
} // namespace mlir

View File

@ -22,6 +22,10 @@ using namespace mlir;
namespace {
struct TestAffineLoopParametricTiling
: public PassWrapper<TestAffineLoopParametricTiling, FunctionPass> {
StringRef getArgument() const final { return "test-affine-parametric-tile"; }
StringRef getDescription() const final {
return "Tile affine loops using SSA values as tile sizes";
}
void runOnFunction() override;
};
} // end anonymous namespace
@ -83,9 +87,7 @@ void TestAffineLoopParametricTiling::runOnFunction() {
namespace mlir {
namespace test {
void registerTestAffineLoopParametricTilingPass() {
PassRegistration<TestAffineLoopParametricTiling>(
"test-affine-parametric-tile",
"Tile affine loops using SSA values as tile sizes");
PassRegistration<TestAffineLoopParametricTiling>();
}
} // namespace test
} // namespace mlir

View File

@ -25,6 +25,10 @@ namespace {
/// This pass applies the permutation on the first maximal perfect nest.
struct TestAffineLoopUnswitching
: public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Tests affine loop unswitching / if/else hoisting";
}
TestAffineLoopUnswitching() = default;
TestAffineLoopUnswitching(const TestAffineLoopUnswitching &pass) {}
@ -54,7 +58,6 @@ void TestAffineLoopUnswitching::runOnFunction() {
namespace mlir {
void registerTestAffineLoopUnswitchingPass() {
PassRegistration<TestAffineLoopUnswitching>(
PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
PassRegistration<TestAffineLoopUnswitching>();
}
} // namespace mlir

View File

@ -27,6 +27,10 @@ namespace {
/// This pass applies the permutation on the first maximal perfect nest.
struct TestLoopPermutation
: public PassWrapper<TestLoopPermutation, FunctionPass> {
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Tests affine loop permutation utility";
}
TestLoopPermutation() = default;
TestLoopPermutation(const TestLoopPermutation &pass){};
@ -62,7 +66,6 @@ void TestLoopPermutation::runOnFunction() {
namespace mlir {
void registerTestLoopPermutationPass() {
PassRegistration<TestLoopPermutation>(
PASS_NAME, "Tests affine loop permutation utility");
PassRegistration<TestLoopPermutation>();
}
} // namespace mlir

View File

@ -75,6 +75,10 @@ struct VectorizerTestPass
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
StringRef getArgument() const final { return "affine-super-vectorizer-test"; }
StringRef getDescription() const final {
return "Tests vectorizer standalone functionality.";
}
void runOnFunction() override;
void testVectorShapeRatio(llvm::raw_ostream &outs);
@ -269,9 +273,5 @@ void VectorizerTestPass::runOnFunction() {
}
namespace mlir {
void registerVectorizerTestPass() {
PassRegistration<VectorizerTestPass> pass(
"affine-super-vectorizer-test",
"Tests vectorizer standalone functionality.");
}
void registerVectorizerTestPass() { PassRegistration<VectorizerTestPass>(); }
} // namespace mlir

View File

@ -21,6 +21,8 @@ namespace {
/// result types.
struct TestDataLayoutQuery
: public PassWrapper<TestDataLayoutQuery, FunctionPass> {
StringRef getArgument() const final { return "test-data-layout-query"; }
StringRef getDescription() const final { return "Test data layout queries"; }
void runOnFunction() override {
FuncOp func = getFunction();
Builder builder(func.getContext());
@ -48,9 +50,6 @@ struct TestDataLayoutQuery
namespace mlir {
namespace test {
void registerTestDataLayoutQuery() {
PassRegistration<TestDataLayoutQuery>("test-data-layout-query",
"Test data layout queries");
}
void registerTestDataLayoutQuery() { PassRegistration<TestDataLayoutQuery>(); }
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,10 @@ namespace {
class TestSerializeToCubinPass
: public PassWrapper<TestSerializeToCubinPass, gpu::SerializeToBlobPass> {
public:
StringRef getArgument() const final { return "test-gpu-to-cubin"; }
StringRef getDescription() const final {
return "Lower GPU kernel function to CUBIN binary annotations";
}
TestSerializeToCubinPass();
private:
@ -53,17 +57,15 @@ namespace mlir {
namespace test {
// Register test pass to serialize GPU module to a CUBIN binary annotation.
void registerTestGpuSerializeToCubinPass() {
PassRegistration<TestSerializeToCubinPass> registerSerializeToCubin(
"test-gpu-to-cubin",
"Lower GPU kernel function to CUBIN binary annotations", [] {
// Initialize LLVM NVPTX backend.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
PassRegistration<TestSerializeToCubinPass>([] {
// Initialize LLVM NVPTX backend.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
return std::make_unique<TestSerializeToCubinPass>();
});
return std::make_unique<TestSerializeToCubinPass>();
});
}
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,10 @@ namespace {
class TestSerializeToHsacoPass
: public PassWrapper<TestSerializeToHsacoPass, gpu::SerializeToBlobPass> {
public:
StringRef getArgument() const final { return "test-gpu-to-hsaco"; }
StringRef getDescription() const final {
return "Lower GPU kernel function to HSAco binary annotations";
}
TestSerializeToHsacoPass();
private:
@ -52,17 +56,15 @@ namespace mlir {
namespace test {
// Register test pass to serialize GPU module to a HSAco binary annotation.
void registerTestGpuSerializeToHsacoPass() {
PassRegistration<TestSerializeToHsacoPass> registerSerializeToHsaco(
"test-gpu-to-hsaco",
"Lower GPU kernel function to HSAco binary annotations", [] {
// Initialize LLVM AMDGPU backend.
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
PassRegistration<TestSerializeToHsacoPass>([] {
// Initialize LLVM AMDGPU backend.
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
return std::make_unique<TestSerializeToHsacoPass>();
});
return std::make_unique<TestSerializeToHsacoPass>();
});
}
} // namespace test
} // namespace mlir

View File

@ -35,6 +35,10 @@ class TestGpuMemoryPromotionPass
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect,
scf::SCFDialect>();
}
StringRef getArgument() const final { return "test-gpu-memory-promotion"; }
StringRef getDescription() const final {
return "Promotes the annotated arguments of gpu.func to workgroup memory.";
}
void runOnOperation() override {
gpu::GPUFuncOp op = getOperation();
@ -48,8 +52,6 @@ class TestGpuMemoryPromotionPass
namespace mlir {
void registerTestGpuMemoryPromotionPass() {
PassRegistration<TestGpuMemoryPromotionPass>(
"test-gpu-memory-promotion",
"Promotes the annotated arguments of gpu.func to workgroup memory.");
PassRegistration<TestGpuMemoryPromotionPass>();
}
} // namespace mlir

View File

@ -22,6 +22,12 @@ namespace {
class TestGpuGreedyParallelLoopMappingPass
: public PassWrapper<TestGpuGreedyParallelLoopMappingPass,
OperationPass<FuncOp>> {
StringRef getArgument() const final {
return "test-gpu-greedy-parallel-loop-mapping";
}
StringRef getDescription() const final {
return "Greedily maps all parallel loops to gpu hardware ids.";
}
void runOnOperation() override {
Operation *op = getOperation();
for (Region &region : op->getRegions())
@ -33,9 +39,7 @@ class TestGpuGreedyParallelLoopMappingPass
namespace mlir {
namespace test {
void registerTestGpuParallelLoopMappingPass() {
PassRegistration<TestGpuGreedyParallelLoopMappingPass> registration(
"test-gpu-greedy-parallel-loop-mapping",
"Greedily maps all parallel loops to gpu hardware ids.");
PassRegistration<TestGpuGreedyParallelLoopMappingPass>();
}
} // namespace test
} // namespace mlir

View File

@ -24,6 +24,10 @@ struct TestGpuRewritePass
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect, memref::MemRefDialect>();
}
StringRef getArgument() const final { return "test-gpu-rewrite"; }
StringRef getDescription() const final {
return "Applies all rewrite patterns within the GPU dialect.";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateGpuRewritePatterns(patterns);
@ -34,8 +38,6 @@ struct TestGpuRewritePass
namespace mlir {
void registerTestAllReduceLoweringPass() {
PassRegistration<TestGpuRewritePass> pass(
"test-gpu-rewrite",
"Applies all rewrite patterns within the GPU dialect.");
PassRegistration<TestGpuRewritePass>();
}
} // namespace mlir

View File

@ -26,6 +26,10 @@ namespace {
class TestConvVectorization
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
public:
StringRef getArgument() const final { return "test-conv-vectorization"; }
StringRef getDescription() const final {
return "Test vectorization of convolutions";
}
TestConvVectorization() = default;
TestConvVectorization(const TestConvVectorization &) {}
explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
@ -129,8 +133,7 @@ void TestConvVectorization::runOnOperation() {
namespace mlir {
namespace test {
void registerTestConvVectorization() {
PassRegistration<TestConvVectorization> testTransformPatternsPass(
"test-conv-vectorization", "Test vectorization of convolutions");
PassRegistration<TestConvVectorization>();
}
} // namespace test
} // namespace mlir

View File

@ -28,6 +28,10 @@ using namespace mlir::linalg;
namespace {
struct TestLinalgCodegenStrategy
: public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
StringRef getArgument() const final { return "test-linalg-codegen-strategy"; }
StringRef getDescription() const final {
return "Test Linalg Codegen Strategy.";
}
TestLinalgCodegenStrategy() = default;
TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
@ -227,8 +231,7 @@ void TestLinalgCodegenStrategy::runOnFunction() {
namespace mlir {
namespace test {
void registerTestLinalgCodegenStrategy() {
PassRegistration<TestLinalgCodegenStrategy> testLinalgCodegenStrategyPass(
"test-linalg-codegen-strategy", "Test Linalg Codegen Strategy.");
PassRegistration<TestLinalgCodegenStrategy>();
}
} // namespace test
} // namespace mlir

View File

@ -40,6 +40,8 @@ static LinalgLoopDistributionOptions getDistributionOptions() {
namespace {
struct TestLinalgDistribution
: public PassWrapper<TestLinalgDistribution, FunctionPass> {
StringRef getArgument() const final { return "test-linalg-distribution"; }
StringRef getDescription() const final { return "Test Linalg distribution."; }
TestLinalgDistribution() = default;
TestLinalgDistribution(const TestLinalgDistribution &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
@ -72,8 +74,7 @@ void TestLinalgDistribution::runOnFunction() {
namespace mlir {
namespace test {
void registerTestLinalgDistribution() {
PassRegistration<TestLinalgDistribution> testTestLinalgDistributionPass(
"test-linalg-distribution", "Test Linalg distribution.");
PassRegistration<TestLinalgDistribution>();
}
} // namespace test
} // namespace mlir

View File

@ -51,6 +51,12 @@ struct TestLinalgElementwiseFusion
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-elementwise-fusion-patterns";
}
StringRef getDescription() const final {
return "Test Linalg element wise operation fusion patterns";
}
void runOnFunction() override {
MLIRContext *context = &this->getContext();
@ -73,6 +79,10 @@ struct TestPushExpandingReshape
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
StringRef getDescription() const final {
return "Test Linalg reshape push patterns";
}
void runOnFunction() override {
MLIRContext *context = &this->getContext();
@ -86,14 +96,11 @@ struct TestPushExpandingReshape
namespace test {
void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
"test-linalg-elementwise-fusion-patterns",
"Test Linalg element wise operation fusion patterns");
PassRegistration<TestLinalgElementwiseFusion>();
}
void registerTestPushExpandingReshape() {
PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
"test-linalg-push-reshape", "Test Linalg reshape push patterns");
PassRegistration<TestPushExpandingReshape>();
}
} // namespace test

View File

@ -108,16 +108,15 @@ static void fillFusionPatterns(MLIRContext *context,
}
namespace {
template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
template <LinalgTilingLoopType LoopType>
struct TestLinalgFusionTransforms
: public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
TestLinalgFusionTransforms() = default;
TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect, StandardOpsDialect>();
}
TestLinalgFusionTransforms() = default;
TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
void runOnFunction() override {
MLIRContext *context = &this->getContext();
@ -130,6 +129,39 @@ struct TestLinalgFusionTransforms
(void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
}
};
struct TestLinalgFusionTransformsParallelLoops
: public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
StringRef getArgument() const final {
return "test-linalg-fusion-transform-patterns";
}
StringRef getDescription() const final {
return "Test Linalg fusion transformation patterns by applying them "
"greedily.";
}
};
struct TestLinalgFusionTransformsLoops
: public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
StringRef getArgument() const final {
return "test-linalg-tensor-fusion-transform-patterns";
}
StringRef getDescription() const final {
return "Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.";
}
};
struct TestLinalgFusionTransformsTiledLoops
: public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
StringRef getArgument() const final {
return "test-linalg-tiled-loop-fusion-transform-patterns";
}
StringRef getDescription() const final {
return "Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.";
}
};
} // namespace
static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
@ -195,6 +227,10 @@ struct TestLinalgGreedyFusion
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect>();
}
StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
StringRef getDescription() const final {
return "Test Linalg fusion by applying a greedy test transformation.";
}
void runOnFunction() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns =
@ -218,6 +254,10 @@ struct TestLinalgGreedyFusion
/// testing.
struct TestLinalgTileAndFuseSequencePass
: public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
StringRef getDescription() const final {
return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
}
TestLinalgTileAndFuseSequencePass() = default;
TestLinalgTileAndFuseSequencePass(
const TestLinalgTileAndFuseSequencePass &pass){};
@ -261,39 +301,25 @@ struct TestLinalgTileAndFuseSequencePass
op.erase();
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestLinalgFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass(
"test-linalg-fusion-transform-patterns",
"Test Linalg fusion transformation patterns by applying them greedily.");
PassRegistration<TestLinalgFusionTransformsParallelLoops>();
}
void registerTestLinalgTensorFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>>
testTensorFusionTransformsPass(
"test-linalg-tensor-fusion-transform-patterns",
"Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.");
PassRegistration<TestLinalgFusionTransformsLoops>();
}
void registerTestLinalgTiledLoopFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
testTiledLoopFusionTransformsPass(
"test-linalg-tiled-loop-fusion-transform-patterns",
"Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.");
PassRegistration<TestLinalgFusionTransformsTiledLoops>();
}
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
"test-linalg-greedy-fusion",
"Test Linalg fusion by applying a greedy test transformation.");
PassRegistration<TestLinalgGreedyFusion>();
}
void registerTestLinalgTileAndFuseSequencePass() {
PassRegistration<TestLinalgTileAndFuseSequencePass>
testTileAndFuseSequencePass(
"test-linalg-tile-and-fuse",
"Test Linalg tiling and fusion of a sequence of Linalg operations.");
PassRegistration<TestLinalgTileAndFuseSequencePass>();
}
} // namespace test

View File

@ -26,6 +26,10 @@ struct TestLinalgHoisting
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
StringRef getArgument() const final { return "test-linalg-hoisting"; }
StringRef getDescription() const final {
return "Test Linalg hoisting functions.";
}
void runOnFunction() override;
@ -46,9 +50,6 @@ void TestLinalgHoisting::runOnFunction() {
namespace mlir {
namespace test {
void registerTestLinalgHoisting() {
PassRegistration<TestLinalgHoisting> testTestLinalgHoistingPass(
"test-linalg-hoisting", "Test Linalg hoisting functions.");
}
void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); }
} // namespace test
} // namespace mlir

View File

@ -42,6 +42,12 @@ struct TestLinalgTransforms
gpu::GPUDialect>();
// clang-format on
}
StringRef getArgument() const final {
return "test-linalg-transform-patterns";
}
StringRef getDescription() const final {
return "Test Linalg transformation patterns by applying them greedily.";
}
void runOnFunction() override;
@ -612,9 +618,7 @@ void TestLinalgTransforms::runOnFunction() {
namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
"test-linalg-transform-patterns",
"Test Linalg transformation patterns by applying them greedily.");
PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,8 @@ namespace {
struct TestExpandTanhPass
: public PassWrapper<TestExpandTanhPass, FunctionPass> {
void runOnFunction() override;
StringRef getArgument() const final { return "test-expand-tanh"; }
StringRef getDescription() const final { return "Test expanding tanh"; }
};
} // end anonymous namespace
@ -31,9 +33,6 @@ void TestExpandTanhPass::runOnFunction() {
namespace mlir {
namespace test {
void registerTestExpandTanhPass() {
PassRegistration<TestExpandTanhPass> pass("test-expand-tanh",
"Test expanding tanh");
}
void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
} // namespace test
} // namespace mlir

View File

@ -28,6 +28,12 @@ struct TestMathPolynomialApproximationPass
registry
.insert<vector::VectorDialect, math::MathDialect, LLVM::LLVMDialect>();
}
StringRef getArgument() const final {
return "test-math-polynomial-approximation";
}
StringRef getDescription() const final {
return "Test math polynomial approximations";
}
};
} // end anonymous namespace
@ -40,9 +46,7 @@ void TestMathPolynomialApproximationPass::runOnFunction() {
namespace mlir {
namespace test {
void registerTestMathPolynomialApproximationPass() {
PassRegistration<TestMathPolynomialApproximationPass> pass(
"test-math-polynomial-approximation",
"Test math polynomial approximations");
PassRegistration<TestMathPolynomialApproximationPass>();
}
} // namespace test
} // namespace mlir

View File

@ -24,7 +24,9 @@ namespace {
class TestSCFForUtilsPass
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
public:
explicit TestSCFForUtilsPass() {}
StringRef getArgument() const final { return "test-scf-for-utils"; }
StringRef getDescription() const final { return "test scf.for utils"; }
explicit TestSCFForUtilsPass() = default;
void runOnFunction() override {
FuncOp func = getFunction();
@ -54,7 +56,9 @@ public:
class TestSCFIfUtilsPass
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
public:
explicit TestSCFIfUtilsPass() {}
StringRef getArgument() const final { return "test-scf-if-utils"; }
StringRef getDescription() const final { return "test scf.if utils"; }
explicit TestSCFIfUtilsPass() = default;
void runOnFunction() override {
int count = 0;
@ -73,10 +77,8 @@ public:
namespace mlir {
namespace test {
void registerTestSCFUtilsPass() {
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
"test scf.for utils");
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
"test scf.if utils");
PassRegistration<TestSCFForUtilsPass>();
PassRegistration<TestSCFIfUtilsPass>();
}
} // namespace test
} // namespace mlir

View File

@ -23,6 +23,10 @@ namespace {
struct PrintOpAvailability
: public PassWrapper<PrintOpAvailability, FunctionPass> {
void runOnFunction() override;
StringRef getArgument() const final { return "test-spirv-op-availability"; }
StringRef getDescription() const final {
return "Test SPIR-V op availability";
}
};
} // end anonymous namespace
@ -78,8 +82,7 @@ void PrintOpAvailability::runOnFunction() {
namespace mlir {
void registerPrintOpAvailabilityPass() {
PassRegistration<PrintOpAvailability> printOpAvailabilityPass(
"test-spirv-op-availability", "Test SPIR-V op availability");
PassRegistration<PrintOpAvailability>();
}
} // namespace mlir
@ -91,6 +94,10 @@ namespace {
/// A pass for testing SPIR-V op availability.
struct ConvertToTargetEnv
: public PassWrapper<ConvertToTargetEnv, FunctionPass> {
StringRef getArgument() const override { return "test-spirv-target-env"; }
StringRef getDescription() const override {
return "Test SPIR-V target environment";
}
void runOnFunction() override;
};
@ -225,7 +232,6 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
namespace mlir {
void registerConvertToTargetEnvPass() {
PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass(
"test-spirv-target-env", "Test SPIR-V target environment");
PassRegistration<ConvertToTargetEnv>();
}
} // namespace mlir

View File

@ -24,6 +24,12 @@ class TestSpirvEntryPointABIPass
: public PassWrapper<TestSpirvEntryPointABIPass,
OperationPass<gpu::GPUModuleOp>> {
public:
StringRef getArgument() const final { return "test-spirv-entry-point-abi"; }
StringRef getDescription() const final {
return "Set the spv.entry_point_abi attribute on GPU kernel function "
"within the "
"module, intended for testing only";
}
TestSpirvEntryPointABIPass() = default;
TestSpirvEntryPointABIPass(const TestSpirvEntryPointABIPass &) {}
void runOnOperation() override;
@ -56,9 +62,6 @@ void TestSpirvEntryPointABIPass::runOnOperation() {
namespace mlir {
void registerTestSpirvEntryPointABIPass() {
PassRegistration<TestSpirvEntryPointABIPass> registration(
"test-spirv-entry-point-abi",
"Set the spv.entry_point_abi attribute on GPU kernel function within the "
"module, intended for testing only");
PassRegistration<TestSpirvEntryPointABIPass>();
}
} // namespace mlir

View File

@ -20,6 +20,12 @@ public:
TestGLSLCanonicalizationPass() = default;
TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {}
void runOnOperation() override;
StringRef getArgument() const final {
return "test-spirv-glsl-canonicalization";
}
StringRef getDescription() const final {
return "Tests SPIR-V canonicalization patterns for GLSL extension.";
}
};
} // namespace
@ -31,8 +37,6 @@ void TestGLSLCanonicalizationPass::runOnOperation() {
namespace mlir {
void registerTestSpirvGLSLCanonicalizationPass() {
PassRegistration<TestGLSLCanonicalizationPass> registration(
"test-spirv-glsl-canonicalization",
"Tests SPIR-V canonicalization patterns for GLSL extension.");
PassRegistration<TestGLSLCanonicalizationPass>();
}
} // namespace mlir

View File

@ -20,6 +20,10 @@ class TestModuleCombinerPass
: public PassWrapper<TestModuleCombinerPass,
OperationPass<mlir::ModuleOp>> {
public:
StringRef getArgument() const final { return "test-spirv-module-combiner"; }
StringRef getDescription() const final {
return "Tests SPIR-V module combiner library";
}
TestModuleCombinerPass() = default;
TestModuleCombinerPass(const TestModuleCombinerPass &) {}
void runOnOperation() override;
@ -41,7 +45,6 @@ void TestModuleCombinerPass::runOnOperation() {
namespace mlir {
void registerTestSpirvModuleCombinerPass() {
PassRegistration<TestModuleCombinerPass> registration(
"test-spirv-module-combiner", "Tests SPIR-V module combiner library");
PassRegistration<TestModuleCombinerPass>();
}
} // namespace mlir

View File

@ -20,6 +20,10 @@ namespace {
struct ReportShapeFnPass
: public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
StringRef getArgument() const final { return "test-shape-function-report"; }
StringRef getDescription() const final {
return "Test pass to report associated shape functions";
}
};
} // end anonymous namespace
@ -82,8 +86,6 @@ void ReportShapeFnPass::runOnOperation() {
namespace mlir {
void registerShapeFunctionTestPasses() {
PassRegistration<ReportShapeFnPass>(
"test-shape-function-report",
"Test pass to report associated shape functions");
PassRegistration<ReportShapeFnPass>();
}
} // namespace mlir

View File

@ -20,6 +20,10 @@ using namespace mlir;
namespace {
struct TestComposeSubViewPass
: public PassWrapper<TestComposeSubViewPass, FunctionPass> {
StringRef getArgument() const final { return "test-compose-subview"; }
StringRef getDescription() const final {
return "Test combining composed subviews";
}
void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override;
};
@ -39,8 +43,7 @@ void TestComposeSubViewPass::runOnFunction() {
namespace mlir {
namespace test {
void registerTestComposeSubView() {
PassRegistration<TestComposeSubViewPass> pass(
"test-compose-subview", "Test combining composed subviews");
PassRegistration<TestComposeSubViewPass>();
}
} // namespace test
} // namespace mlir

View File

@ -27,6 +27,12 @@ struct TestDecomposeCallGraphTypes
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<test::TestDialect>();
}
StringRef getArgument() const final {
return "test-decompose-call-graph-types";
}
StringRef getDescription() const final {
return "Decomposes types at call graph boundaries.";
}
void runOnOperation() override {
ModuleOp module = getOperation();
auto *context = &getContext();
@ -87,9 +93,7 @@ struct TestDecomposeCallGraphTypes
namespace mlir {
namespace test {
void registerTestDecomposeCallGraphTypes() {
PassRegistration<TestDecomposeCallGraphTypes> pass(
"test-decompose-call-graph-types",
"Decomposes types at call graph boundaries.");
PassRegistration<TestDecomposeCallGraphTypes>();
}
} // namespace test
} // namespace mlir

View File

@ -95,6 +95,8 @@ public:
};
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
StringRef getArgument() const final { return "test-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnFunction() override {
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
@ -159,6 +161,8 @@ struct TestReturnTypeDriver
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
StringRef getArgument() const final { return "test-return-type"; }
StringRef getDescription() const final { return "Run return type functions"; }
void runOnFunction() override {
if (getFunction().getName() == "testCreateFunctions") {
@ -194,6 +198,10 @@ struct TestReturnTypeDriver
namespace {
struct TestDerivedAttributeDriver
: public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
StringRef getArgument() const final { return "test-derived-attr"; }
StringRef getDescription() const final {
return "Run test derived attributes";
}
void runOnFunction() override;
};
} // end anonymous namespace
@ -585,6 +593,10 @@ struct TestTypeConverter : public TypeConverter {
struct TestLegalizePatternDriver
: public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-legalize-patterns"; }
StringRef getDescription() const final {
return "Run test dialect legalization patterns";
}
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };
@ -733,6 +745,10 @@ struct OneVResOneVOperandOp1Converter
struct TestRemappedValue
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
StringRef getArgument() const final { return "test-remapped-value"; }
StringRef getDescription() const final {
return "Test public remapped value mechanism in ConversionPatternRewriter";
}
void runOnFunction() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
@ -776,6 +792,12 @@ struct RemoveTestDialectOps : public RewritePattern {
struct TestUnknownRootOpDriver
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
StringRef getArgument() const final {
return "test-legalize-unknown-root-patterns";
}
StringRef getDescription() const final {
return "Test public remapped value mechanism in ConversionPatternRewriter";
}
void runOnFunction() override {
mlir::RewritePatternSet patterns(&getContext());
patterns.add<RemoveTestDialectOps>(&getContext());
@ -857,6 +879,12 @@ struct TestTypeConversionDriver
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TestDialect>();
}
StringRef getArgument() const final {
return "test-legalize-type-conversion";
}
StringRef getDescription() const final {
return "Test various type conversion functionalities in DialectConversion";
}
void runOnOperation() override {
// Initialize the type converter.
@ -999,6 +1027,10 @@ struct TestMergeSingleBlockOps
struct TestMergeBlocksPatternDriver
: public PassWrapper<TestMergeBlocksPatternDriver,
OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-merge-blocks"; }
StringRef getDescription() const final {
return "Test Merging operation in ConversionPatternRewriter";
}
void runOnOperation() override {
MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
@ -1066,6 +1098,12 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
struct TestSelectiveReplacementPatternDriver
: public PassWrapper<TestSelectiveReplacementPatternDriver,
OperationPass<>> {
StringRef getArgument() const final {
return "test-pattern-selective-replacement";
}
StringRef getDescription() const final {
return "Test selective replacement in the PatternRewriter";
}
void runOnOperation() override {
MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
@ -1083,39 +1121,24 @@ struct TestSelectiveReplacementPatternDriver
namespace mlir {
namespace test {
void registerPatternsTestPass() {
PassRegistration<TestReturnTypeDriver>("test-return-type",
"Run return type functions");
PassRegistration<TestReturnTypeDriver>();
PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
"Run test derived attributes");
PassRegistration<TestDerivedAttributeDriver>();
PassRegistration<TestPatternDriver>("test-patterns",
"Run test dialect patterns");
PassRegistration<TestPatternDriver>();
PassRegistration<TestLegalizePatternDriver>(
"test-legalize-patterns", "Run test dialect legalization patterns", [] {
return std::make_unique<TestLegalizePatternDriver>(
legalizerConversionMode);
});
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
});
PassRegistration<TestRemappedValue>(
"test-remapped-value",
"Test public remapped value mechanism in ConversionPatternRewriter");
PassRegistration<TestRemappedValue>();
PassRegistration<TestUnknownRootOpDriver>(
"test-legalize-unknown-root-patterns",
"Test public remapped value mechanism in ConversionPatternRewriter");
PassRegistration<TestUnknownRootOpDriver>();
PassRegistration<TestTypeConversionDriver>(
"test-legalize-type-conversion",
"Test various type conversion functionalities in DialectConversion");
PassRegistration<TestTypeConversionDriver>();
PassRegistration<TestMergeBlocksPatternDriver>{
"test-merge-blocks",
"Test Merging operation in ConversionPatternRewriter"};
PassRegistration<TestSelectiveReplacementPatternDriver>{
"test-pattern-selective-replacement",
"Test selective replacement in the PatternRewriter"};
PassRegistration<TestMergeBlocksPatternDriver>();
PassRegistration<TestSelectiveReplacementPatternDriver>();
}
} // namespace test
} // namespace mlir

View File

@ -32,6 +32,8 @@ OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
namespace {
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
StringRef getArgument() const final { return "test-trait-folder"; }
StringRef getDescription() const final { return "Run trait folding"; }
void runOnFunction() override {
(void)applyPatternsAndFoldGreedily(getFunction(),
RewritePatternSet(&getContext()));
@ -40,7 +42,5 @@ struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
} // end anonymous namespace
namespace mlir {
void registerTestTraitsPass() {
PassRegistration<TestTraitFolder>("test-trait-folder", "Run trait folding");
}
void registerTestTraitsPass() { PassRegistration<TestTraitFolder>(); }
} // namespace mlir

View File

@ -179,6 +179,10 @@ namespace {
struct TosaTestQuantUtilAPI
: public PassWrapper<TosaTestQuantUtilAPI, FunctionPass> {
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
}
void runOnFunction() override;
};
@ -196,7 +200,6 @@ void TosaTestQuantUtilAPI::runOnFunction() {
namespace mlir {
void registerTosaTestQuantUtilAPIPass() {
PassRegistration<TosaTestQuantUtilAPI>(
PASS_NAME, "TOSA Test: Exercise the APIs in QuantUtils.cpp.");
PassRegistration<TosaTestQuantUtilAPI>();
}
} // namespace mlir

View File

@ -27,6 +27,12 @@ struct TestVectorToVectorConversion
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
TestVectorToVectorConversion() = default;
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
StringRef getArgument() const final {
return "test-vector-to-vector-conversion";
}
StringRef getDescription() const final {
return "Test conversion patterns between ops in the vector dialect";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
@ -69,6 +75,13 @@ private:
struct TestVectorSlicesConversion
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-slices-conversion";
}
StringRef getDescription() const final {
return "Test conversion patterns that lower slices ops in the vector "
"dialect";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVectorSlicesLoweringPatterns(patterns);
@ -78,6 +91,13 @@ struct TestVectorSlicesConversion
struct TestVectorContractionConversion
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-contraction-conversion";
}
StringRef getDescription() const final {
return "Test conversion patterns that lower contract ops in the vector "
"dialect";
}
TestVectorContractionConversion() = default;
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
}
@ -146,6 +166,13 @@ struct TestVectorContractionConversion
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to unroll contract ops in the vector "
"dialect";
}
TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
@ -199,6 +226,13 @@ struct TestVectorUnrollingPatterns
struct TestVectorDistributePatterns
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-distribute-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to distribute vector ops in the vector "
"dialect";
}
TestVectorDistributePatterns() = default;
TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
@ -249,6 +283,10 @@ struct TestVectorDistributePatterns
struct TestVectorToLoopPatterns
: public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
StringRef getArgument() const final { return "test-vector-to-forloop"; }
StringRef getDescription() const final {
return "Test conversion patterns to break up a vector op into a for loop";
}
TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
@ -312,6 +350,13 @@ struct TestVectorTransferUnrollingPatterns
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to unroll transfer ops in the vector "
"dialect";
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
@ -332,6 +377,13 @@ struct TestVectorTransferUnrollingPatterns
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
StringRef getArgument() const final {
return "test-vector-transfer-full-partial-split";
}
StringRef getDescription() const final {
return "Test conversion patterns to split "
"transfer ops via scf.if + linalg ops";
}
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
@ -361,6 +413,10 @@ struct TestVectorTransferFullPartialSplitPatterns
struct TestVectorTransferOpt
: public PassWrapper<TestVectorTransferOpt, FunctionPass> {
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
StringRef getDescription() const final {
return "Test optimization transformations for transfer ops";
}
void runOnFunction() override { transferOpflowOpt(getFunction()); }
};
@ -369,6 +425,12 @@ struct TestVectorTransferLoweringPatterns
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-lowering-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to lower transfer ops to other vector ops";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
@ -382,6 +444,13 @@ struct TestVectorMultiReductionLoweringPatterns
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-multi-reduction-lowering-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to lower vector.multi_reduction to other "
"vector ops";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVectorMultiReductionLoweringPatterns(patterns);
@ -394,53 +463,27 @@ struct TestVectorMultiReductionLoweringPatterns
namespace mlir {
namespace test {
void registerTestVectorConversions() {
PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
"test-vector-to-vector-conversion",
"Test conversion patterns between ops in the vector dialect");
PassRegistration<TestVectorToVectorConversion>();
PassRegistration<TestVectorSlicesConversion> slicesPass(
"test-vector-slices-conversion",
"Test conversion patterns that lower slices ops in the vector dialect");
PassRegistration<TestVectorSlicesConversion>();
PassRegistration<TestVectorContractionConversion> contractionPass(
"test-vector-contraction-conversion",
"Test conversion patterns that lower contract ops in the vector dialect");
PassRegistration<TestVectorContractionConversion>();
PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
"test-vector-unrolling-patterns",
"Test conversion patterns to unroll contract ops in the vector dialect");
PassRegistration<TestVectorUnrollingPatterns>();
PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
"test-vector-transfer-unrolling-patterns",
"Test conversion patterns to unroll transfer ops in the vector dialect");
PassRegistration<TestVectorTransferUnrollingPatterns>();
PassRegistration<TestVectorTransferFullPartialSplitPatterns>
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
"Test conversion patterns to split "
"transfer ops via scf.if + linalg ops");
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
PassRegistration<TestVectorDistributePatterns> distributePass(
"test-vector-distribute-patterns",
"Test conversion patterns to distribute vector ops in the vector "
"dialect");
PassRegistration<TestVectorDistributePatterns>();
PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
"test-vector-to-forloop",
"Test conversion patterns to break up a vector op into a for loop");
PassRegistration<TestVectorToLoopPatterns>();
PassRegistration<TestVectorTransferOpt> transferOpOpt(
"test-vector-transferop-opt",
"Test optimization transformations for transfer ops");
PassRegistration<TestVectorTransferOpt>();
PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
"test-vector-transfer-lowering-patterns",
"Test conversion patterns to lower transfer ops to other vector ops");
PassRegistration<TestVectorTransferLoweringPatterns>();
PassRegistration<TestVectorMultiReductionLoweringPatterns>
multiDimReductionOpLoweringPass(
"test-vector-multi-reduction-lowering-patterns",
"Test conversion patterns to lower vector.multi_reduction to other "
"vector ops");
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
}
} // namespace test
} // namespace mlir

View File

@ -91,6 +91,10 @@ private:
};
struct TestDominancePass : public PassWrapper<TestDominancePass, FunctionPass> {
StringRef getArgument() const final { return "test-print-dominance"; }
StringRef getDescription() const final {
return "Print the dominance information for multiple regions.";
}
void runOnFunction() override {
llvm::errs() << "Testing : " << getFunction().getName() << "\n";
@ -120,10 +124,6 @@ struct TestDominancePass : public PassWrapper<TestDominancePass, FunctionPass> {
namespace mlir {
namespace test {
void registerTestDominancePass() {
PassRegistration<TestDominancePass>(
"test-print-dominance",
"Print the dominance information for multiple regions.");
}
void registerTestDominancePass() { PassRegistration<TestDominancePass>(); }
} // namespace test
} // namespace mlir

View File

@ -15,6 +15,8 @@ namespace {
/// This is a test pass for verifying FuncOp's eraseArgument method.
struct TestFuncEraseArg
: public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-func-erase-arg"; }
StringRef getDescription() const final { return "Test erasing func args."; }
void runOnOperation() override {
auto module = getOperation();
@ -39,21 +41,28 @@ struct TestFuncEraseArg
/// This is a test pass for verifying FuncOp's eraseResult method.
struct TestFuncEraseResult
: public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-func-erase-result"; }
StringRef getDescription() const final {
return "Test erasing func results.";
}
void runOnOperation() override {
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
SmallVector<unsigned, 4> indicesToErase;
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
// Push back twice to test that duplicate indices are handled
// correctly.
if (func.getResultAttr(resultIndex, "test.erase_this_"
"result")) {
// Push back twice to test
// that duplicate indices
// are handled correctly.
indicesToErase.push_back(resultIndex);
indicesToErase.push_back(resultIndex);
}
}
// Reverse the order to test that unsorted index lists are handled
// correctly.
// Reverse the order to test
// that unsorted index lists are
// handled correctly.
std::reverse(indicesToErase.begin(), indicesToErase.end());
func.eraseResults(indicesToErase);
}
@ -63,6 +72,8 @@ struct TestFuncEraseResult
/// This is a test pass for verifying FuncOp's setType method.
struct TestFuncSetType
: public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-func-set-type"; }
StringRef getDescription() const final { return "Test FuncOp::setType."; }
void runOnOperation() override {
auto module = getOperation();
SymbolTable symbolTable(module);
@ -79,13 +90,10 @@ struct TestFuncSetType
namespace mlir {
void registerTestFunc() {
PassRegistration<TestFuncEraseArg>("test-func-erase-arg",
"Test erasing func args.");
PassRegistration<TestFuncEraseArg>();
PassRegistration<TestFuncEraseResult>("test-func-erase-result",
"Test erasing func results.");
PassRegistration<TestFuncEraseResult>();
PassRegistration<TestFuncSetType>("test-func-set-type",
"Test FuncOp::setType.");
PassRegistration<TestFuncSetType>();
}
} // namespace mlir

View File

@ -17,6 +17,10 @@ namespace {
/// application.
struct TestTypeInterfaces
: public PassWrapper<TestTypeInterfaces, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-type-interfaces"; }
StringRef getDescription() const final {
return "Test type interface support.";
}
void runOnOperation() override {
getOperation().walk([](Operation *op) {
for (Type type : op->getResultTypes()) {
@ -40,9 +44,6 @@ struct TestTypeInterfaces
namespace mlir {
namespace test {
void registerTestInterfaces() {
PassRegistration<TestTypeInterfaces> pass("test-type-interfaces",
"Test type interface support.");
}
void registerTestInterfaces() { PassRegistration<TestTypeInterfaces>(); }
} // namespace test
} // namespace mlir

View File

@ -17,6 +17,10 @@ namespace {
/// This is a test pass for verifying matchers.
struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> {
void runOnFunction() override;
StringRef getArgument() const final { return "test-matchers"; }
StringRef getDescription() const final {
return "Test C++ pattern matchers.";
}
};
} // end anonymous namespace
@ -148,7 +152,5 @@ void TestMatchers::runOnFunction() {
}
namespace mlir {
void registerTestMatchers() {
PassRegistration<TestMatchers>("test-matchers", "Test C++ pattern matchers.");
}
void registerTestMatchers() { PassRegistration<TestMatchers>(); }
} // namespace mlir

View File

@ -18,6 +18,10 @@ namespace {
/// locations.
struct TestOpaqueLoc
: public PassWrapper<TestOpaqueLoc, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-opaque-loc"; }
StringRef getDescription() const final {
return "Changes all leaf locations to opaque locations";
}
/// A simple structure which is used for testing as an underlying location in
/// OpaqueLoc.
@ -82,9 +86,6 @@ struct TestOpaqueLoc
namespace mlir {
namespace test {
void registerTestOpaqueLoc() {
PassRegistration<TestOpaqueLoc> pass(
"test-opaque-loc", "Changes all leaf locations to opaque locations");
}
void registerTestOpaqueLoc() { PassRegistration<TestOpaqueLoc>(); }
} // namespace test
} // namespace mlir

View File

@ -16,6 +16,8 @@ namespace {
/// This pass illustrates the IR def-use chains through printing.
struct TestPrintDefUsePass
: public PassWrapper<TestPrintDefUsePass, OperationPass<>> {
StringRef getArgument() const final { return "test-print-defuse"; }
StringRef getDescription() const final { return "Test various printing."; }
void runOnOperation() override {
// Recursively traverse the IR nested under the current operation and print
// every single operation and their operands and users.
@ -64,8 +66,5 @@ struct TestPrintDefUsePass
} // end anonymous namespace
namespace mlir {
void registerTestPrintDefUsePass() {
PassRegistration<TestPrintDefUsePass>("test-print-defuse",
"Test various printing.");
}
void registerTestPrintDefUsePass() { PassRegistration<TestPrintDefUsePass>(); }
} // namespace mlir

View File

@ -16,6 +16,8 @@ namespace {
/// This pass illustrates the IR nesting through printing.
struct TestPrintNestingPass
: public PassWrapper<TestPrintNestingPass, OperationPass<>> {
StringRef getArgument() const final { return "test-print-nesting"; }
StringRef getDescription() const final { return "Test various printing."; }
// Entry point for the pass.
void runOnOperation() override {
Operation *op = getOperation();
@ -90,7 +92,6 @@ struct TestPrintNestingPass
namespace mlir {
void registerTestPrintNestingPass() {
PassRegistration<TestPrintNestingPass>("test-print-nesting",
"Test various printing.");
PassRegistration<TestPrintNestingPass>();
}
} // namespace mlir

View File

@ -14,6 +14,10 @@ using namespace mlir;
namespace {
struct SideEffectsPass
: public PassWrapper<SideEffectsPass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-side-effects"; }
StringRef getDescription() const final {
return "Test side effects interfaces";
}
void runOnOperation() override {
auto module = getOperation();
@ -68,8 +72,5 @@ struct SideEffectsPass
} // end anonymous namespace
namespace mlir {
void registerSideEffectTestPasses() {
PassRegistration<SideEffectsPass>("test-side-effects",
"Test side effects interfaces");
}
void registerSideEffectTestPasses() { PassRegistration<SideEffectsPass>(); }
} // namespace mlir

View File

@ -47,6 +47,10 @@ namespace {
/// Pass to test slice generated from slice analysis.
struct SliceAnalysisTestPass
: public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "slice-analysis-test"; }
StringRef getDescription() const final {
return "Test Slice analysis functionality.";
}
void runOnOperation() override;
SliceAnalysisTestPass() = default;
SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
@ -74,7 +78,6 @@ void SliceAnalysisTestPass::runOnOperation() {
namespace mlir {
void registerSliceAnalysisTestPass() {
PassRegistration<SliceAnalysisTestPass> pass(
"slice-analysis-test", "Test Slice analysis functionality.");
PassRegistration<SliceAnalysisTestPass>();
}
} // namespace mlir

View File

@ -17,6 +17,10 @@ namespace {
/// provided by the symbol table along with erasing from the symbol table.
struct SymbolUsesPass
: public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-symbol-uses"; }
StringRef getDescription() const final {
return "Test detection of symbol uses";
}
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
SmallVectorImpl<FuncOp> &deadFunctions) {
// Test computing uses on a non symboltable op.
@ -89,6 +93,10 @@ struct SymbolUsesPass
/// functionality provided by the symbol table.
struct SymbolReplacementPass
: public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-symbol-rauw"; }
StringRef getDescription() const final {
return "Test replacement of symbol uses";
}
void runOnOperation() override {
ModuleOp module = getOperation();
@ -111,10 +119,8 @@ struct SymbolReplacementPass
namespace mlir {
void registerSymbolTestPasses() {
PassRegistration<SymbolUsesPass>("test-symbol-uses",
"Test detection of symbol uses");
PassRegistration<SymbolUsesPass>();
PassRegistration<SymbolReplacementPass>("test-symbol-rauw",
"Test replacement of symbol uses");
PassRegistration<SymbolReplacementPass>();
}
} // namespace mlir

View File

@ -18,6 +18,10 @@ struct TestRecursiveTypesPass
: public PassWrapper<TestRecursiveTypesPass, FunctionPass> {
LogicalResult createIRWithTypes();
StringRef getArgument() const final { return "test-recursive-types"; }
StringRef getDescription() const final {
return "Test support for recursive types";
}
void runOnFunction() override {
FuncOp func = getFunction();
@ -73,8 +77,7 @@ namespace mlir {
namespace test {
void registerTestRecursiveTypesPass() {
PassRegistration<TestRecursiveTypesPass> reg(
"test-recursive-types", "Test support for recursive types");
PassRegistration<TestRecursiveTypesPass>();
}
} // namespace test

View File

@ -152,6 +152,8 @@ namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
: public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
StringRef getArgument() const final { return "test-ir-visitors"; }
StringRef getDescription() const final { return "Test various visitors."; }
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
@ -163,9 +165,6 @@ struct TestIRVisitorsPass
namespace mlir {
namespace test {
void registerTestIRVisitorsPass() {
PassRegistration<TestIRVisitorsPass>("test-ir-visitors",
"Test various visitors.");
}
void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,11 @@ namespace {
class TestDynamicPipelinePass
: public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
public:
StringRef getArgument() const final { return "test-dynamic-pipeline"; }
StringRef getDescription() const final {
return "Tests the dynamic pipeline feature by applying "
"a pipeline on a selected set of functions";
}
void getDependentDialects(DialectRegistry &registry) const override {
OpPassManager pm(ModuleOp::getOperationName(),
OpPassManager::Nesting::Implicit);
@ -106,9 +111,7 @@ public:
namespace mlir {
namespace test {
void registerTestDynamicPipelinePass() {
PassRegistration<TestDynamicPipelinePass>(
"test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
"a pipeline on a selected set of functions");
PassRegistration<TestDynamicPipelinePass>();
}
} // namespace test
} // namespace mlir

View File

@ -17,10 +17,16 @@ struct TestModulePass
: public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
void runOnOperation() final {}
StringRef getArgument() const final { return "test-module-pass"; }
StringRef getDescription() const final {
return "Test a module pass in the pass manager";
}
};
struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
void runOnFunction() final {}
StringRef getArgument() const final { return "test-function-pass"; }
StringRef getDescription() const final {
return "Test a function pass in the pass manager";
}
};
class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
public:
@ -44,6 +50,9 @@ public:
void runOnFunction() final {}
StringRef getArgument() const final { return "test-options-pass"; }
StringRef getDescription() const final {
return "Test options parsing capabilities";
}
ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Example list option")};
@ -60,12 +69,19 @@ class TestCrashRecoveryPass
: public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
void runOnOperation() final { abort(); }
StringRef getArgument() const final { return "test-pass-crash"; }
StringRef getDescription() const final {
return "Test a pass in the pass manager that always crashes";
}
};
/// A test pass that always fails to enable testing the failure recovery
/// mechanisms of the pass manager.
class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
void runOnOperation() final { signalPassFailure(); }
StringRef getArgument() const final { return "test-pass-failure"; }
StringRef getDescription() const final {
return "Test a pass in the pass manager that always fails";
}
};
/// A test pass that contains a statistic.
@ -73,6 +89,8 @@ struct TestStatisticPass
: public PassWrapper<TestStatisticPass, OperationPass<>> {
TestStatisticPass() = default;
TestStatisticPass(const TestStatisticPass &) {}
StringRef getArgument() const final { return "test-stats-pass"; }
StringRef getDescription() const final { return "Test pass statistics"; }
Statistic opCount{this, "num-ops", "Number of operations counted"};
@ -102,22 +120,16 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
namespace mlir {
void registerPassManagerTestPass() {
PassRegistration<TestOptionsPass>("test-options-pass",
"Test options parsing capabilities");
PassRegistration<TestOptionsPass>();
PassRegistration<TestModulePass>("test-module-pass",
"Test a module pass in the pass manager");
PassRegistration<TestModulePass>();
PassRegistration<TestFunctionPass>(
"test-function-pass", "Test a function pass in the pass manager");
PassRegistration<TestFunctionPass>();
PassRegistration<TestCrashRecoveryPass>(
"test-pass-crash", "Test a pass in the pass manager that always crashes");
PassRegistration<TestFailurePass>(
"test-pass-failure", "Test a pass in the pass manager that always fails");
PassRegistration<TestCrashRecoveryPass>();
PassRegistration<TestFailurePass>();
PassRegistration<TestStatisticPass> unusedStatP("test-stats-pass",
"Test pass statistics");
PassRegistration<TestStatisticPass>();
PassPipelineRegistration<>("test-pm-nested-pipeline",
"Test a nested pipeline in the pass manager",

View File

@ -26,6 +26,10 @@ namespace {
/// "crashOp" in the input MLIR file and crashes the mlir-opt tool if the
/// operation is found.
struct TestReducer : public PassWrapper<TestReducer, FunctionPass> {
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Tests MLIR Reduce tool by generating failures";
}
TestReducer() = default;
TestReducer(const TestReducer &pass){};
void runOnFunction() override;
@ -47,8 +51,5 @@ void TestReducer::runOnFunction() {
}
namespace mlir {
void registerTestReducer() {
PassRegistration<TestReducer>(
PASS_NAME, "Tests MLIR Reduce tool by generating failures");
}
void registerTestReducer() { PassRegistration<TestReducer>(); }
} // namespace mlir

View File

@ -71,6 +71,10 @@ static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
namespace {
struct TestPDLByteCodePass
: public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
StringRef getDescription() const final {
return "Test PDL ByteCode functionality";
}
void runOnOperation() final {
ModuleOp module = getOperation();
@ -107,9 +111,6 @@ struct TestPDLByteCodePass
namespace mlir {
namespace test {
void registerTestPDLByteCodePass() {
PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
"Test PDL ByteCode functionality");
}
void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
} // namespace test
} // namespace mlir

View File

@ -16,6 +16,10 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
struct TestConstantFold : public PassWrapper<TestConstantFold, FunctionPass> {
StringRef getArgument() const final { return "test-constant-fold"; }
StringRef getDescription() const final {
return "Test operation constant folding";
}
// All constants in the function post folding.
SmallVector<Operation *, 8> existingConstants;
@ -62,9 +66,6 @@ void TestConstantFold::runOnFunction() {
namespace mlir {
namespace test {
void registerTestConstantFold() {
PassRegistration<TestConstantFold>("test-constant-fold",
"Test operation constant folding");
}
void registerTestConstantFold() { PassRegistration<TestConstantFold>(); }
} // namespace test
} // namespace mlir

View File

@ -26,6 +26,11 @@ using namespace mlir::test;
namespace {
struct Inliner : public PassWrapper<Inliner, FunctionPass> {
StringRef getArgument() const final { return "test-inline"; }
StringRef getDescription() const final {
return "Test inlining region calls";
}
void runOnFunction() override {
auto function = getFunction();
@ -63,8 +68,6 @@ struct Inliner : public PassWrapper<Inliner, FunctionPass> {
namespace mlir {
namespace test {
void registerInliner() {
PassRegistration<Inliner>("test-inline", "Test inlining region calls");
}
void registerInliner() { PassRegistration<Inliner>(); }
} // namespace test
} // namespace mlir

View File

@ -42,6 +42,10 @@ static llvm::cl::opt<bool> clTestLoopFusionTransformation(
namespace {
struct TestLoopFusion : public PassWrapper<TestLoopFusion, FunctionPass> {
StringRef getArgument() const final { return "test-loop-fusion"; }
StringRef getDescription() const final {
return "Tests loop fusion utility functions.";
}
void runOnFunction() override;
};
@ -198,9 +202,6 @@ void TestLoopFusion::runOnFunction() {
namespace mlir {
namespace test {
void registerTestLoopFusion() {
PassRegistration<TestLoopFusion>("test-loop-fusion",
"Tests loop fusion utility functions.");
}
void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
} // namespace test
} // namespace mlir

View File

@ -26,6 +26,12 @@ namespace {
class TestLoopMappingPass
: public PassWrapper<TestLoopMappingPass, FunctionPass> {
public:
StringRef getArgument() const final {
return "test-mapping-to-processing-elements";
}
StringRef getDescription() const final {
return "test mapping a single loop on a virtual processor grid";
}
explicit TestLoopMappingPass() {}
void getDependentDialects(DialectRegistry &registry) const override {
@ -58,10 +64,6 @@ public:
namespace mlir {
namespace test {
void registerTestLoopMappingPass() {
PassRegistration<TestLoopMappingPass>(
"test-mapping-to-processing-elements",
"test mapping a single loop on a virtual processor grid");
}
void registerTestLoopMappingPass() { PassRegistration<TestLoopMappingPass>(); }
} // namespace test
} // namespace mlir

View File

@ -25,6 +25,14 @@ namespace {
class SimpleParametricLoopTilingPass
: public PassWrapper<SimpleParametricLoopTilingPass, FunctionPass> {
public:
StringRef getArgument() const final {
return "test-extract-fixed-outer-loops";
}
StringRef getDescription() const final {
return "test application of parametric tiling to the outer loops so that "
"the "
"ranges of outer loops become static";
}
SimpleParametricLoopTilingPass() = default;
SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {}
explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) {
@ -51,10 +59,7 @@ public:
namespace mlir {
namespace test {
void registerSimpleParametricTilingPass() {
PassRegistration<SimpleParametricLoopTilingPass>(
"test-extract-fixed-outer-loops",
"test application of parametric tiling to the outer loops so that the "
"ranges of outer loops become static");
PassRegistration<SimpleParametricLoopTilingPass>();
}
} // namespace test
} // namespace mlir

View File

@ -33,6 +33,10 @@ static unsigned getNestingDepth(Operation *op) {
class TestLoopUnrollingPass
: public PassWrapper<TestLoopUnrollingPass, FunctionPass> {
public:
StringRef getArgument() const final { return "test-loop-unrolling"; }
StringRef getDescription() const final {
return "Tests loop unrolling transformation";
}
TestLoopUnrollingPass() = default;
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
@ -65,8 +69,7 @@ public:
namespace mlir {
namespace test {
void registerTestLoopUnrollingPass() {
PassRegistration<TestLoopUnrollingPass>(
"test-loop-unrolling", "Tests loop unrolling transformation");
PassRegistration<TestLoopUnrollingPass>();
}
} // namespace test
} // namespace mlir