forked from OSchip/llvm-project
Migrate MLIR test passes to the new registration API
Make sure they all define getArgument()/getDescription(). Depends On D104421 Differential Revision: https://reviews.llvm.org/D104426
This commit is contained in:
parent
c8a3f561eb
commit
b5e22e6d42
|
@ -47,6 +47,11 @@ class SerializeToCubinPass
|
|||
public:
|
||||
SerializeToCubinPass();
|
||||
|
||||
StringRef getArgument() const override { return "gpu-to-cubin"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Lower GPU kernel function to CUBIN binary annotations";
|
||||
}
|
||||
|
||||
private:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override;
|
||||
|
||||
|
@ -126,7 +131,6 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
|
|||
// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
|
||||
void mlir::registerGpuSerializeToCubinPass() {
|
||||
PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
|
||||
"gpu-to-cubin", "Lower GPU kernel function to CUBIN binary annotations",
|
||||
[] {
|
||||
// Initialize LLVM NVPTX backend.
|
||||
LLVMInitializeNVPTXTarget();
|
||||
|
|
|
@ -50,6 +50,11 @@ class SerializeToHsacoPass
|
|||
public:
|
||||
SerializeToHsacoPass();
|
||||
|
||||
StringRef getArgument() const override { return "gpu-to-hsaco"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Lower GPU kernel function to HSACO binary annotations";
|
||||
}
|
||||
|
||||
private:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override;
|
||||
|
||||
|
@ -268,7 +273,6 @@ SerializeToHsacoPass::serializeISA(const std::string &isa) {
|
|||
// Register pass to serialize GPU kernel functions to a HSACO binary annotation.
|
||||
void mlir::registerGpuSerializeToHsacoPass() {
|
||||
PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
|
||||
"gpu-to-hsaco", "Lower GPU kernel function to HSACO binary annotations",
|
||||
[] {
|
||||
// Initialize LLVM AMDGPU backend.
|
||||
LLVMInitializeAMDGPUAsmParser();
|
||||
|
|
|
@ -46,6 +46,10 @@ static void printAliasOperand(Value value) {
|
|||
namespace {
|
||||
struct TestAliasAnalysisPass
|
||||
: public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
|
||||
StringRef getArgument() const final { return "test-alias-analysis"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test alias analysis results.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
|
||||
|
||||
|
@ -84,6 +88,10 @@ struct TestAliasAnalysisPass
|
|||
namespace {
|
||||
struct TestAliasAnalysisModRefPass
|
||||
: public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
|
||||
StringRef getArgument() const final { return "test-alias-analysis-modref"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test alias analysis ModRef results.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
|
||||
|
||||
|
@ -126,10 +134,8 @@ struct TestAliasAnalysisModRefPass
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAliasAnalysisPass() {
|
||||
PassRegistration<TestAliasAnalysisPass> aliasPass(
|
||||
"test-alias-analysis", "Test alias analysis results.");
|
||||
PassRegistration<TestAliasAnalysisModRefPass> modRefPass(
|
||||
"test-alias-analysis-modref", "Test alias analysis ModRef results.");
|
||||
PassRegistration<TestAliasAnalysisPass>();
|
||||
PassRegistration<TestAliasAnalysisModRefPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -19,6 +19,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct TestCallGraphPass
|
||||
: public PassWrapper<TestCallGraphPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-print-callgraph"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Print the contents of a constructed callgraph.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
llvm::errs() << "Testing : " << getOperation()->getAttr("test.name")
|
||||
<< "\n";
|
||||
|
@ -29,9 +33,6 @@ struct TestCallGraphPass
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestCallGraphPass() {
|
||||
PassRegistration<TestCallGraphPass> pass(
|
||||
"test-print-callgraph", "Print the contents of a constructed callgraph.");
|
||||
}
|
||||
void registerTestCallGraphPass() { PassRegistration<TestCallGraphPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
|
||||
struct TestLivenessPass : public PassWrapper<TestLivenessPass, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-print-liveness"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Print the contents of a constructed liveness information.";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
llvm::errs() << "Testing : " << getFunction().getName() << "\n";
|
||||
getAnalysis<Liveness>().print(llvm::errs());
|
||||
|
@ -30,10 +34,6 @@ struct TestLivenessPass : public PassWrapper<TestLivenessPass, FunctionPass> {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLivenessPass() {
|
||||
PassRegistration<TestLivenessPass>(
|
||||
"test-print-liveness",
|
||||
"Print the contents of a constructed liveness information.");
|
||||
}
|
||||
void registerTestLivenessPass() { PassRegistration<TestLivenessPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -29,6 +29,10 @@ namespace {
|
|||
/// Checks for out of bound memref access subscripts..
|
||||
struct TestMemRefBoundCheck
|
||||
: public PassWrapper<TestMemRefBoundCheck, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-memref-bound-check"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Check memref access bounds in a Function";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -46,9 +50,6 @@ void TestMemRefBoundCheck::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerMemRefBoundCheck() {
|
||||
PassRegistration<TestMemRefBoundCheck>(
|
||||
"test-memref-bound-check", "Check memref access bounds in a Function");
|
||||
}
|
||||
void registerMemRefBoundCheck() { PassRegistration<TestMemRefBoundCheck>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -28,6 +28,10 @@ namespace {
|
|||
/// Checks dependences between all pairs of memref accesses in a Function.
|
||||
struct TestMemRefDependenceCheck
|
||||
: public PassWrapper<TestMemRefDependenceCheck, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-memref-dependence-check"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Checks dependences between all pairs of memref accesses.";
|
||||
}
|
||||
SmallVector<Operation *, 4> loadsAndStores;
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
@ -112,9 +116,7 @@ void TestMemRefDependenceCheck::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMemRefDependenceCheck() {
|
||||
PassRegistration<TestMemRefDependenceCheck> pass(
|
||||
"test-memref-dependence-check",
|
||||
"Checks dependences between all pairs of memref accesses.");
|
||||
PassRegistration<TestMemRefDependenceCheck>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,6 +15,12 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct TestMemRefStrideCalculation
|
||||
: public PassWrapper<TestMemRefStrideCalculation, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-memref-stride-calculation";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test operation constant folding";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -51,8 +57,7 @@ void TestMemRefStrideCalculation::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMemRefStrideCalculation() {
|
||||
PassRegistration<TestMemRefStrideCalculation> pass(
|
||||
"test-memref-stride-calculation", "Test operation constant folding");
|
||||
PassRegistration<TestMemRefStrideCalculation>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,14 @@ namespace {
|
|||
|
||||
struct TestNumberOfBlockExecutionsPass
|
||||
: public PassWrapper<TestNumberOfBlockExecutionsPass, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-print-number-of-block-executions";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Print the contents of a constructed number of executions analysis "
|
||||
"for "
|
||||
"all blocks.";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
|
||||
getAnalysis<NumberOfExecutions>().printBlockExecutions(
|
||||
|
@ -29,6 +37,14 @@ struct TestNumberOfBlockExecutionsPass
|
|||
|
||||
struct TestNumberOfOperationExecutionsPass
|
||||
: public PassWrapper<TestNumberOfOperationExecutionsPass, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-print-number-of-operation-executions";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Print the contents of a constructed number of executions analysis "
|
||||
"for "
|
||||
"all operations.";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
|
||||
getAnalysis<NumberOfExecutions>().printOperationExecutions(
|
||||
|
@ -41,17 +57,11 @@ struct TestNumberOfOperationExecutionsPass
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestNumberOfBlockExecutionsPass() {
|
||||
PassRegistration<TestNumberOfBlockExecutionsPass>(
|
||||
"test-print-number-of-block-executions",
|
||||
"Print the contents of a constructed number of executions analysis for "
|
||||
"all blocks.");
|
||||
PassRegistration<TestNumberOfBlockExecutionsPass>();
|
||||
}
|
||||
|
||||
void registerTestNumberOfOperationExecutionsPass() {
|
||||
PassRegistration<TestNumberOfOperationExecutionsPass>(
|
||||
"test-print-number-of-operation-executions",
|
||||
"Print the contents of a constructed number of executions analysis for "
|
||||
"all operations.");
|
||||
PassRegistration<TestNumberOfOperationExecutionsPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -38,6 +38,11 @@ public:
|
|||
void getDependentDialects(DialectRegistry ®istry) const final {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-convert-call-op"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests conversion of `std.call` to `llvm.call` in "
|
||||
"presence of custom types";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp m = getOperation();
|
||||
|
@ -68,11 +73,6 @@ public:
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerConvertCallOpPass() {
|
||||
PassRegistration<TestConvertCallOp>(
|
||||
"test-convert-call-op",
|
||||
"Tests conversion of `std.call` to `llvm.call` in "
|
||||
"presence of custom types");
|
||||
}
|
||||
void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -29,6 +29,10 @@ namespace {
|
|||
|
||||
struct TestAffineDataCopy
|
||||
: public PassWrapper<TestAffineDataCopy, FunctionPass> {
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests affine data copy utility functions.";
|
||||
}
|
||||
TestAffineDataCopy() = default;
|
||||
TestAffineDataCopy(const TestAffineDataCopy &pass){};
|
||||
|
||||
|
@ -128,7 +132,6 @@ void TestAffineDataCopy::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestAffineDataCopyPass() {
|
||||
PassRegistration<TestAffineDataCopy>(
|
||||
PASS_NAME, "Tests affine data copy utility functions.");
|
||||
PassRegistration<TestAffineDataCopy>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -22,6 +22,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct TestAffineLoopParametricTiling
|
||||
: public PassWrapper<TestAffineLoopParametricTiling, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-affine-parametric-tile"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tile affine loops using SSA values as tile sizes";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -83,9 +87,7 @@ void TestAffineLoopParametricTiling::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAffineLoopParametricTilingPass() {
|
||||
PassRegistration<TestAffineLoopParametricTiling>(
|
||||
"test-affine-parametric-tile",
|
||||
"Tile affine loops using SSA values as tile sizes");
|
||||
PassRegistration<TestAffineLoopParametricTiling>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -25,6 +25,10 @@ namespace {
|
|||
/// This pass applies the permutation on the first maximal perfect nest.
|
||||
struct TestAffineLoopUnswitching
|
||||
: public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests affine loop unswitching / if/else hoisting";
|
||||
}
|
||||
TestAffineLoopUnswitching() = default;
|
||||
TestAffineLoopUnswitching(const TestAffineLoopUnswitching &pass) {}
|
||||
|
||||
|
@ -54,7 +58,6 @@ void TestAffineLoopUnswitching::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestAffineLoopUnswitchingPass() {
|
||||
PassRegistration<TestAffineLoopUnswitching>(
|
||||
PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
|
||||
PassRegistration<TestAffineLoopUnswitching>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -27,6 +27,10 @@ namespace {
|
|||
/// This pass applies the permutation on the first maximal perfect nest.
|
||||
struct TestLoopPermutation
|
||||
: public PassWrapper<TestLoopPermutation, FunctionPass> {
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests affine loop permutation utility";
|
||||
}
|
||||
TestLoopPermutation() = default;
|
||||
TestLoopPermutation(const TestLoopPermutation &pass){};
|
||||
|
||||
|
@ -62,7 +66,6 @@ void TestLoopPermutation::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestLoopPermutationPass() {
|
||||
PassRegistration<TestLoopPermutation>(
|
||||
PASS_NAME, "Tests affine loop permutation utility");
|
||||
PassRegistration<TestLoopPermutation>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -75,6 +75,10 @@ struct VectorizerTestPass
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<vector::VectorDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "affine-super-vectorizer-test"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests vectorizer standalone functionality.";
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
void testVectorShapeRatio(llvm::raw_ostream &outs);
|
||||
|
@ -269,9 +273,5 @@ void VectorizerTestPass::runOnFunction() {
|
|||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerVectorizerTestPass() {
|
||||
PassRegistration<VectorizerTestPass> pass(
|
||||
"affine-super-vectorizer-test",
|
||||
"Tests vectorizer standalone functionality.");
|
||||
}
|
||||
void registerVectorizerTestPass() { PassRegistration<VectorizerTestPass>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -21,6 +21,8 @@ namespace {
|
|||
/// result types.
|
||||
struct TestDataLayoutQuery
|
||||
: public PassWrapper<TestDataLayoutQuery, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-data-layout-query"; }
|
||||
StringRef getDescription() const final { return "Test data layout queries"; }
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
Builder builder(func.getContext());
|
||||
|
@ -48,9 +50,6 @@ struct TestDataLayoutQuery
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDataLayoutQuery() {
|
||||
PassRegistration<TestDataLayoutQuery>("test-data-layout-query",
|
||||
"Test data layout queries");
|
||||
}
|
||||
void registerTestDataLayoutQuery() { PassRegistration<TestDataLayoutQuery>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ namespace {
|
|||
class TestSerializeToCubinPass
|
||||
: public PassWrapper<TestSerializeToCubinPass, gpu::SerializeToBlobPass> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-gpu-to-cubin"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Lower GPU kernel function to CUBIN binary annotations";
|
||||
}
|
||||
TestSerializeToCubinPass();
|
||||
|
||||
private:
|
||||
|
@ -53,17 +57,15 @@ namespace mlir {
|
|||
namespace test {
|
||||
// Register test pass to serialize GPU module to a CUBIN binary annotation.
|
||||
void registerTestGpuSerializeToCubinPass() {
|
||||
PassRegistration<TestSerializeToCubinPass> registerSerializeToCubin(
|
||||
"test-gpu-to-cubin",
|
||||
"Lower GPU kernel function to CUBIN binary annotations", [] {
|
||||
// Initialize LLVM NVPTX backend.
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
PassRegistration<TestSerializeToCubinPass>([] {
|
||||
// Initialize LLVM NVPTX backend.
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
|
||||
return std::make_unique<TestSerializeToCubinPass>();
|
||||
});
|
||||
return std::make_unique<TestSerializeToCubinPass>();
|
||||
});
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ namespace {
|
|||
class TestSerializeToHsacoPass
|
||||
: public PassWrapper<TestSerializeToHsacoPass, gpu::SerializeToBlobPass> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-gpu-to-hsaco"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Lower GPU kernel function to HSAco binary annotations";
|
||||
}
|
||||
TestSerializeToHsacoPass();
|
||||
|
||||
private:
|
||||
|
@ -52,17 +56,15 @@ namespace mlir {
|
|||
namespace test {
|
||||
// Register test pass to serialize GPU module to a HSAco binary annotation.
|
||||
void registerTestGpuSerializeToHsacoPass() {
|
||||
PassRegistration<TestSerializeToHsacoPass> registerSerializeToHsaco(
|
||||
"test-gpu-to-hsaco",
|
||||
"Lower GPU kernel function to HSAco binary annotations", [] {
|
||||
// Initialize LLVM AMDGPU backend.
|
||||
LLVMInitializeAMDGPUTarget();
|
||||
LLVMInitializeAMDGPUTargetInfo();
|
||||
LLVMInitializeAMDGPUTargetMC();
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
PassRegistration<TestSerializeToHsacoPass>([] {
|
||||
// Initialize LLVM AMDGPU backend.
|
||||
LLVMInitializeAMDGPUTarget();
|
||||
LLVMInitializeAMDGPUTargetInfo();
|
||||
LLVMInitializeAMDGPUTargetMC();
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
|
||||
return std::make_unique<TestSerializeToHsacoPass>();
|
||||
});
|
||||
return std::make_unique<TestSerializeToHsacoPass>();
|
||||
});
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -35,6 +35,10 @@ class TestGpuMemoryPromotionPass
|
|||
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect,
|
||||
scf::SCFDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-gpu-memory-promotion"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Promotes the annotated arguments of gpu.func to workgroup memory.";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
gpu::GPUFuncOp op = getOperation();
|
||||
|
@ -48,8 +52,6 @@ class TestGpuMemoryPromotionPass
|
|||
|
||||
namespace mlir {
|
||||
void registerTestGpuMemoryPromotionPass() {
|
||||
PassRegistration<TestGpuMemoryPromotionPass>(
|
||||
"test-gpu-memory-promotion",
|
||||
"Promotes the annotated arguments of gpu.func to workgroup memory.");
|
||||
PassRegistration<TestGpuMemoryPromotionPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -22,6 +22,12 @@ namespace {
|
|||
class TestGpuGreedyParallelLoopMappingPass
|
||||
: public PassWrapper<TestGpuGreedyParallelLoopMappingPass,
|
||||
OperationPass<FuncOp>> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-gpu-greedy-parallel-loop-mapping";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Greedily maps all parallel loops to gpu hardware ids.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
for (Region ®ion : op->getRegions())
|
||||
|
@ -33,9 +39,7 @@ class TestGpuGreedyParallelLoopMappingPass
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestGpuParallelLoopMappingPass() {
|
||||
PassRegistration<TestGpuGreedyParallelLoopMappingPass> registration(
|
||||
"test-gpu-greedy-parallel-loop-mapping",
|
||||
"Greedily maps all parallel loops to gpu hardware ids.");
|
||||
PassRegistration<TestGpuGreedyParallelLoopMappingPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -24,6 +24,10 @@ struct TestGpuRewritePass
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<StandardOpsDialect, memref::MemRefDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-gpu-rewrite"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Applies all rewrite patterns within the GPU dialect.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateGpuRewritePatterns(patterns);
|
||||
|
@ -34,8 +38,6 @@ struct TestGpuRewritePass
|
|||
|
||||
namespace mlir {
|
||||
void registerTestAllReduceLoweringPass() {
|
||||
PassRegistration<TestGpuRewritePass> pass(
|
||||
"test-gpu-rewrite",
|
||||
"Applies all rewrite patterns within the GPU dialect.");
|
||||
PassRegistration<TestGpuRewritePass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -26,6 +26,10 @@ namespace {
|
|||
class TestConvVectorization
|
||||
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-conv-vectorization"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test vectorization of convolutions";
|
||||
}
|
||||
TestConvVectorization() = default;
|
||||
TestConvVectorization(const TestConvVectorization &) {}
|
||||
explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
|
||||
|
@ -129,8 +133,7 @@ void TestConvVectorization::runOnOperation() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestConvVectorization() {
|
||||
PassRegistration<TestConvVectorization> testTransformPatternsPass(
|
||||
"test-conv-vectorization", "Test vectorization of convolutions");
|
||||
PassRegistration<TestConvVectorization>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -28,6 +28,10 @@ using namespace mlir::linalg;
|
|||
namespace {
|
||||
struct TestLinalgCodegenStrategy
|
||||
: public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-linalg-codegen-strategy"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg Codegen Strategy.";
|
||||
}
|
||||
TestLinalgCodegenStrategy() = default;
|
||||
TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
|
||||
|
||||
|
@ -227,8 +231,7 @@ void TestLinalgCodegenStrategy::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgCodegenStrategy() {
|
||||
PassRegistration<TestLinalgCodegenStrategy> testLinalgCodegenStrategyPass(
|
||||
"test-linalg-codegen-strategy", "Test Linalg Codegen Strategy.");
|
||||
PassRegistration<TestLinalgCodegenStrategy>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -40,6 +40,8 @@ static LinalgLoopDistributionOptions getDistributionOptions() {
|
|||
namespace {
|
||||
struct TestLinalgDistribution
|
||||
: public PassWrapper<TestLinalgDistribution, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-linalg-distribution"; }
|
||||
StringRef getDescription() const final { return "Test Linalg distribution."; }
|
||||
TestLinalgDistribution() = default;
|
||||
TestLinalgDistribution(const TestLinalgDistribution &pass) {}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -72,8 +74,7 @@ void TestLinalgDistribution::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgDistribution() {
|
||||
PassRegistration<TestLinalgDistribution> testTestLinalgDistributionPass(
|
||||
"test-linalg-distribution", "Test Linalg distribution.");
|
||||
PassRegistration<TestLinalgDistribution>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -51,6 +51,12 @@ struct TestLinalgElementwiseFusion
|
|||
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
||||
tensor::TensorDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-elementwise-fusion-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg element wise operation fusion patterns";
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
|
@ -73,6 +79,10 @@ struct TestPushExpandingReshape
|
|||
registry
|
||||
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg reshape push patterns";
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
|
@ -86,14 +96,11 @@ struct TestPushExpandingReshape
|
|||
|
||||
namespace test {
|
||||
void registerTestLinalgElementwiseFusion() {
|
||||
PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
|
||||
"test-linalg-elementwise-fusion-patterns",
|
||||
"Test Linalg element wise operation fusion patterns");
|
||||
PassRegistration<TestLinalgElementwiseFusion>();
|
||||
}
|
||||
|
||||
void registerTestPushExpandingReshape() {
|
||||
PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
|
||||
"test-linalg-push-reshape", "Test Linalg reshape push patterns");
|
||||
PassRegistration<TestPushExpandingReshape>();
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
|
|
|
@ -108,16 +108,15 @@ static void fillFusionPatterns(MLIRContext *context,
|
|||
}
|
||||
|
||||
namespace {
|
||||
template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
|
||||
template <LinalgTilingLoopType LoopType>
|
||||
struct TestLinalgFusionTransforms
|
||||
: public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
|
||||
TestLinalgFusionTransforms() = default;
|
||||
TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
||||
scf::SCFDialect, StandardOpsDialect>();
|
||||
}
|
||||
TestLinalgFusionTransforms() = default;
|
||||
TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
|
@ -130,6 +129,39 @@ struct TestLinalgFusionTransforms
|
|||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestLinalgFusionTransformsParallelLoops
|
||||
: public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-fusion-transform-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg fusion transformation patterns by applying them "
|
||||
"greedily.";
|
||||
}
|
||||
};
|
||||
|
||||
struct TestLinalgFusionTransformsLoops
|
||||
: public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-tensor-fusion-transform-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.";
|
||||
}
|
||||
};
|
||||
|
||||
struct TestLinalgFusionTransformsTiledLoops
|
||||
: public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-tiled-loop-fusion-transform-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.";
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
|
||||
|
@ -195,6 +227,10 @@ struct TestLinalgGreedyFusion
|
|||
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
||||
scf::SCFDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg fusion by applying a greedy test transformation.";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns =
|
||||
|
@ -218,6 +254,10 @@ struct TestLinalgGreedyFusion
|
|||
/// testing.
|
||||
struct TestLinalgTileAndFuseSequencePass
|
||||
: public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
|
||||
}
|
||||
TestLinalgTileAndFuseSequencePass() = default;
|
||||
TestLinalgTileAndFuseSequencePass(
|
||||
const TestLinalgTileAndFuseSequencePass &pass){};
|
||||
|
@ -261,39 +301,25 @@ struct TestLinalgTileAndFuseSequencePass
|
|||
op.erase();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgFusionTransforms() {
|
||||
PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass(
|
||||
"test-linalg-fusion-transform-patterns",
|
||||
"Test Linalg fusion transformation patterns by applying them greedily.");
|
||||
PassRegistration<TestLinalgFusionTransformsParallelLoops>();
|
||||
}
|
||||
void registerTestLinalgTensorFusionTransforms() {
|
||||
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>>
|
||||
testTensorFusionTransformsPass(
|
||||
"test-linalg-tensor-fusion-transform-patterns",
|
||||
"Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.");
|
||||
PassRegistration<TestLinalgFusionTransformsLoops>();
|
||||
}
|
||||
void registerTestLinalgTiledLoopFusionTransforms() {
|
||||
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
|
||||
testTiledLoopFusionTransformsPass(
|
||||
"test-linalg-tiled-loop-fusion-transform-patterns",
|
||||
"Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.");
|
||||
PassRegistration<TestLinalgFusionTransformsTiledLoops>();
|
||||
}
|
||||
void registerTestLinalgGreedyFusion() {
|
||||
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
|
||||
"test-linalg-greedy-fusion",
|
||||
"Test Linalg fusion by applying a greedy test transformation.");
|
||||
PassRegistration<TestLinalgGreedyFusion>();
|
||||
}
|
||||
void registerTestLinalgTileAndFuseSequencePass() {
|
||||
PassRegistration<TestLinalgTileAndFuseSequencePass>
|
||||
testTileAndFuseSequencePass(
|
||||
"test-linalg-tile-and-fuse",
|
||||
"Test Linalg tiling and fusion of a sequence of Linalg operations.");
|
||||
PassRegistration<TestLinalgTileAndFuseSequencePass>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
@ -26,6 +26,10 @@ struct TestLinalgHoisting
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-linalg-hoisting"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg hoisting functions.";
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
|
@ -46,9 +50,6 @@ void TestLinalgHoisting::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgHoisting() {
|
||||
PassRegistration<TestLinalgHoisting> testTestLinalgHoistingPass(
|
||||
"test-linalg-hoisting", "Test Linalg hoisting functions.");
|
||||
}
|
||||
void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -42,6 +42,12 @@ struct TestLinalgTransforms
|
|||
gpu::GPUDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-transform-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test Linalg transformation patterns by applying them greedily.";
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
|
@ -612,9 +618,7 @@ void TestLinalgTransforms::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLinalgTransforms() {
|
||||
PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
|
||||
"test-linalg-transform-patterns",
|
||||
"Test Linalg transformation patterns by applying them greedily.");
|
||||
PassRegistration<TestLinalgTransforms>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,8 @@ namespace {
|
|||
struct TestExpandTanhPass
|
||||
: public PassWrapper<TestExpandTanhPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
StringRef getArgument() const final { return "test-expand-tanh"; }
|
||||
StringRef getDescription() const final { return "Test expanding tanh"; }
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -31,9 +33,6 @@ void TestExpandTanhPass::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestExpandTanhPass() {
|
||||
PassRegistration<TestExpandTanhPass> pass("test-expand-tanh",
|
||||
"Test expanding tanh");
|
||||
}
|
||||
void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -28,6 +28,12 @@ struct TestMathPolynomialApproximationPass
|
|||
registry
|
||||
.insert<vector::VectorDialect, math::MathDialect, LLVM::LLVMDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-math-polynomial-approximation";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test math polynomial approximations";
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -40,9 +46,7 @@ void TestMathPolynomialApproximationPass::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMathPolynomialApproximationPass() {
|
||||
PassRegistration<TestMathPolynomialApproximationPass> pass(
|
||||
"test-math-polynomial-approximation",
|
||||
"Test math polynomial approximations");
|
||||
PassRegistration<TestMathPolynomialApproximationPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -24,7 +24,9 @@ namespace {
|
|||
class TestSCFForUtilsPass
|
||||
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
|
||||
public:
|
||||
explicit TestSCFForUtilsPass() {}
|
||||
StringRef getArgument() const final { return "test-scf-for-utils"; }
|
||||
StringRef getDescription() const final { return "test scf.for utils"; }
|
||||
explicit TestSCFForUtilsPass() = default;
|
||||
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
|
@ -54,7 +56,9 @@ public:
|
|||
class TestSCFIfUtilsPass
|
||||
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
|
||||
public:
|
||||
explicit TestSCFIfUtilsPass() {}
|
||||
StringRef getArgument() const final { return "test-scf-if-utils"; }
|
||||
StringRef getDescription() const final { return "test scf.if utils"; }
|
||||
explicit TestSCFIfUtilsPass() = default;
|
||||
|
||||
void runOnFunction() override {
|
||||
int count = 0;
|
||||
|
@ -73,10 +77,8 @@ public:
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestSCFUtilsPass() {
|
||||
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
|
||||
"test scf.for utils");
|
||||
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
|
||||
"test scf.if utils");
|
||||
PassRegistration<TestSCFForUtilsPass>();
|
||||
PassRegistration<TestSCFIfUtilsPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -23,6 +23,10 @@ namespace {
|
|||
struct PrintOpAvailability
|
||||
: public PassWrapper<PrintOpAvailability, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
StringRef getArgument() const final { return "test-spirv-op-availability"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test SPIR-V op availability";
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -78,8 +82,7 @@ void PrintOpAvailability::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
void registerPrintOpAvailabilityPass() {
|
||||
PassRegistration<PrintOpAvailability> printOpAvailabilityPass(
|
||||
"test-spirv-op-availability", "Test SPIR-V op availability");
|
||||
PassRegistration<PrintOpAvailability>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -91,6 +94,10 @@ namespace {
|
|||
/// A pass for testing SPIR-V op availability.
|
||||
struct ConvertToTargetEnv
|
||||
: public PassWrapper<ConvertToTargetEnv, FunctionPass> {
|
||||
StringRef getArgument() const override { return "test-spirv-target-env"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Test SPIR-V target environment";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -225,7 +232,6 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
|
|||
|
||||
namespace mlir {
|
||||
void registerConvertToTargetEnvPass() {
|
||||
PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass(
|
||||
"test-spirv-target-env", "Test SPIR-V target environment");
|
||||
PassRegistration<ConvertToTargetEnv>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -24,6 +24,12 @@ class TestSpirvEntryPointABIPass
|
|||
: public PassWrapper<TestSpirvEntryPointABIPass,
|
||||
OperationPass<gpu::GPUModuleOp>> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-spirv-entry-point-abi"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Set the spv.entry_point_abi attribute on GPU kernel function "
|
||||
"within the "
|
||||
"module, intended for testing only";
|
||||
}
|
||||
TestSpirvEntryPointABIPass() = default;
|
||||
TestSpirvEntryPointABIPass(const TestSpirvEntryPointABIPass &) {}
|
||||
void runOnOperation() override;
|
||||
|
@ -56,9 +62,6 @@ void TestSpirvEntryPointABIPass::runOnOperation() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestSpirvEntryPointABIPass() {
|
||||
PassRegistration<TestSpirvEntryPointABIPass> registration(
|
||||
"test-spirv-entry-point-abi",
|
||||
"Set the spv.entry_point_abi attribute on GPU kernel function within the "
|
||||
"module, intended for testing only");
|
||||
PassRegistration<TestSpirvEntryPointABIPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,12 @@ public:
|
|||
TestGLSLCanonicalizationPass() = default;
|
||||
TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {}
|
||||
void runOnOperation() override;
|
||||
StringRef getArgument() const final {
|
||||
return "test-spirv-glsl-canonicalization";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Tests SPIR-V canonicalization patterns for GLSL extension.";
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -31,8 +37,6 @@ void TestGLSLCanonicalizationPass::runOnOperation() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestSpirvGLSLCanonicalizationPass() {
|
||||
PassRegistration<TestGLSLCanonicalizationPass> registration(
|
||||
"test-spirv-glsl-canonicalization",
|
||||
"Tests SPIR-V canonicalization patterns for GLSL extension.");
|
||||
PassRegistration<TestGLSLCanonicalizationPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ class TestModuleCombinerPass
|
|||
: public PassWrapper<TestModuleCombinerPass,
|
||||
OperationPass<mlir::ModuleOp>> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-spirv-module-combiner"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests SPIR-V module combiner library";
|
||||
}
|
||||
TestModuleCombinerPass() = default;
|
||||
TestModuleCombinerPass(const TestModuleCombinerPass &) {}
|
||||
void runOnOperation() override;
|
||||
|
@ -41,7 +45,6 @@ void TestModuleCombinerPass::runOnOperation() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTestSpirvModuleCombinerPass() {
|
||||
PassRegistration<TestModuleCombinerPass> registration(
|
||||
"test-spirv-module-combiner", "Tests SPIR-V module combiner library");
|
||||
PassRegistration<TestModuleCombinerPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ namespace {
|
|||
struct ReportShapeFnPass
|
||||
: public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
StringRef getArgument() const final { return "test-shape-function-report"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test pass to report associated shape functions";
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -82,8 +86,6 @@ void ReportShapeFnPass::runOnOperation() {
|
|||
|
||||
namespace mlir {
|
||||
void registerShapeFunctionTestPasses() {
|
||||
PassRegistration<ReportShapeFnPass>(
|
||||
"test-shape-function-report",
|
||||
"Test pass to report associated shape functions");
|
||||
PassRegistration<ReportShapeFnPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct TestComposeSubViewPass
|
||||
: public PassWrapper<TestComposeSubViewPass, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-compose-subview"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test combining composed subviews";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override;
|
||||
};
|
||||
|
@ -39,8 +43,7 @@ void TestComposeSubViewPass::runOnFunction() {
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestComposeSubView() {
|
||||
PassRegistration<TestComposeSubViewPass> pass(
|
||||
"test-compose-subview", "Test combining composed subviews");
|
||||
PassRegistration<TestComposeSubViewPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -27,6 +27,12 @@ struct TestDecomposeCallGraphTypes
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<test::TestDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-decompose-call-graph-types";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Decomposes types at call graph boundaries.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
@ -87,9 +93,7 @@ struct TestDecomposeCallGraphTypes
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDecomposeCallGraphTypes() {
|
||||
PassRegistration<TestDecomposeCallGraphTypes> pass(
|
||||
"test-decompose-call-graph-types",
|
||||
"Decomposes types at call graph boundaries.");
|
||||
PassRegistration<TestDecomposeCallGraphTypes>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -95,6 +95,8 @@ public:
|
|||
};
|
||||
|
||||
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-patterns"; }
|
||||
StringRef getDescription() const final { return "Run test dialect patterns"; }
|
||||
void runOnFunction() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateWithGenerated(patterns);
|
||||
|
@ -159,6 +161,8 @@ struct TestReturnTypeDriver
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-return-type"; }
|
||||
StringRef getDescription() const final { return "Run return type functions"; }
|
||||
|
||||
void runOnFunction() override {
|
||||
if (getFunction().getName() == "testCreateFunctions") {
|
||||
|
@ -194,6 +198,10 @@ struct TestReturnTypeDriver
|
|||
namespace {
|
||||
struct TestDerivedAttributeDriver
|
||||
: public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-derived-attr"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Run test derived attributes";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -585,6 +593,10 @@ struct TestTypeConverter : public TypeConverter {
|
|||
|
||||
struct TestLegalizePatternDriver
|
||||
: public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-legalize-patterns"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Run test dialect legalization patterns";
|
||||
}
|
||||
/// The mode of conversion to use with the driver.
|
||||
enum class ConversionMode { Analysis, Full, Partial };
|
||||
|
||||
|
@ -733,6 +745,10 @@ struct OneVResOneVOperandOp1Converter
|
|||
|
||||
struct TestRemappedValue
|
||||
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-remapped-value"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test public remapped value mechanism in ConversionPatternRewriter";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
|
||||
|
@ -776,6 +792,12 @@ struct RemoveTestDialectOps : public RewritePattern {
|
|||
|
||||
struct TestUnknownRootOpDriver
|
||||
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-legalize-unknown-root-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test public remapped value mechanism in ConversionPatternRewriter";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<RemoveTestDialectOps>(&getContext());
|
||||
|
@ -857,6 +879,12 @@ struct TestTypeConversionDriver
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TestDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-legalize-type-conversion";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test various type conversion functionalities in DialectConversion";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
// Initialize the type converter.
|
||||
|
@ -999,6 +1027,10 @@ struct TestMergeSingleBlockOps
|
|||
struct TestMergeBlocksPatternDriver
|
||||
: public PassWrapper<TestMergeBlocksPatternDriver,
|
||||
OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-merge-blocks"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Merging operation in ConversionPatternRewriter";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
@ -1066,6 +1098,12 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
|
|||
struct TestSelectiveReplacementPatternDriver
|
||||
: public PassWrapper<TestSelectiveReplacementPatternDriver,
|
||||
OperationPass<>> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-pattern-selective-replacement";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test selective replacement in the PatternRewriter";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
@ -1083,39 +1121,24 @@ struct TestSelectiveReplacementPatternDriver
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerPatternsTestPass() {
|
||||
PassRegistration<TestReturnTypeDriver>("test-return-type",
|
||||
"Run return type functions");
|
||||
PassRegistration<TestReturnTypeDriver>();
|
||||
|
||||
PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
|
||||
"Run test derived attributes");
|
||||
PassRegistration<TestDerivedAttributeDriver>();
|
||||
|
||||
PassRegistration<TestPatternDriver>("test-patterns",
|
||||
"Run test dialect patterns");
|
||||
PassRegistration<TestPatternDriver>();
|
||||
|
||||
PassRegistration<TestLegalizePatternDriver>(
|
||||
"test-legalize-patterns", "Run test dialect legalization patterns", [] {
|
||||
return std::make_unique<TestLegalizePatternDriver>(
|
||||
legalizerConversionMode);
|
||||
});
|
||||
PassRegistration<TestLegalizePatternDriver>([] {
|
||||
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
|
||||
});
|
||||
|
||||
PassRegistration<TestRemappedValue>(
|
||||
"test-remapped-value",
|
||||
"Test public remapped value mechanism in ConversionPatternRewriter");
|
||||
PassRegistration<TestRemappedValue>();
|
||||
|
||||
PassRegistration<TestUnknownRootOpDriver>(
|
||||
"test-legalize-unknown-root-patterns",
|
||||
"Test public remapped value mechanism in ConversionPatternRewriter");
|
||||
PassRegistration<TestUnknownRootOpDriver>();
|
||||
|
||||
PassRegistration<TestTypeConversionDriver>(
|
||||
"test-legalize-type-conversion",
|
||||
"Test various type conversion functionalities in DialectConversion");
|
||||
PassRegistration<TestTypeConversionDriver>();
|
||||
|
||||
PassRegistration<TestMergeBlocksPatternDriver>{
|
||||
"test-merge-blocks",
|
||||
"Test Merging operation in ConversionPatternRewriter"};
|
||||
PassRegistration<TestSelectiveReplacementPatternDriver>{
|
||||
"test-pattern-selective-replacement",
|
||||
"Test selective replacement in the PatternRewriter"};
|
||||
PassRegistration<TestMergeBlocksPatternDriver>();
|
||||
PassRegistration<TestSelectiveReplacementPatternDriver>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -32,6 +32,8 @@ OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
|
|||
|
||||
namespace {
|
||||
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-trait-folder"; }
|
||||
StringRef getDescription() const final { return "Run trait folding"; }
|
||||
void runOnFunction() override {
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(),
|
||||
RewritePatternSet(&getContext()));
|
||||
|
@ -40,7 +42,5 @@ struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
|
|||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestTraitsPass() {
|
||||
PassRegistration<TestTraitFolder>("test-trait-folder", "Run trait folding");
|
||||
}
|
||||
void registerTestTraitsPass() { PassRegistration<TestTraitFolder>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -179,6 +179,10 @@ namespace {
|
|||
|
||||
struct TosaTestQuantUtilAPI
|
||||
: public PassWrapper<TosaTestQuantUtilAPI, FunctionPass> {
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -196,7 +200,6 @@ void TosaTestQuantUtilAPI::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
void registerTosaTestQuantUtilAPIPass() {
|
||||
PassRegistration<TosaTestQuantUtilAPI>(
|
||||
PASS_NAME, "TOSA Test: Exercise the APIs in QuantUtils.cpp.");
|
||||
PassRegistration<TosaTestQuantUtilAPI>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -27,6 +27,12 @@ struct TestVectorToVectorConversion
|
|||
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
|
||||
TestVectorToVectorConversion() = default;
|
||||
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-to-vector-conversion";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns between ops in the vector dialect";
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect>();
|
||||
|
@ -69,6 +75,13 @@ private:
|
|||
|
||||
struct TestVectorSlicesConversion
|
||||
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-slices-conversion";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns that lower slices ops in the vector "
|
||||
"dialect";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorSlicesLoweringPatterns(patterns);
|
||||
|
@ -78,6 +91,13 @@ struct TestVectorSlicesConversion
|
|||
|
||||
struct TestVectorContractionConversion
|
||||
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-contraction-conversion";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns that lower contract ops in the vector "
|
||||
"dialect";
|
||||
}
|
||||
TestVectorContractionConversion() = default;
|
||||
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
|
||||
}
|
||||
|
@ -146,6 +166,13 @@ struct TestVectorContractionConversion
|
|||
|
||||
struct TestVectorUnrollingPatterns
|
||||
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-unrolling-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to unroll contract ops in the vector "
|
||||
"dialect";
|
||||
}
|
||||
TestVectorUnrollingPatterns() = default;
|
||||
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
|
||||
void runOnFunction() override {
|
||||
|
@ -199,6 +226,13 @@ struct TestVectorUnrollingPatterns
|
|||
|
||||
struct TestVectorDistributePatterns
|
||||
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-distribute-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to distribute vector ops in the vector "
|
||||
"dialect";
|
||||
}
|
||||
TestVectorDistributePatterns() = default;
|
||||
TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -249,6 +283,10 @@ struct TestVectorDistributePatterns
|
|||
|
||||
struct TestVectorToLoopPatterns
|
||||
: public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-vector-to-forloop"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to break up a vector op into a for loop";
|
||||
}
|
||||
TestVectorToLoopPatterns() = default;
|
||||
TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -312,6 +350,13 @@ struct TestVectorTransferUnrollingPatterns
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-transfer-unrolling-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to unroll transfer ops in the vector "
|
||||
"dialect";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
MLIRContext *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
|
@ -332,6 +377,13 @@ struct TestVectorTransferUnrollingPatterns
|
|||
struct TestVectorTransferFullPartialSplitPatterns
|
||||
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
|
||||
FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-transfer-full-partial-split";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to split "
|
||||
"transfer ops via scf.if + linalg ops";
|
||||
}
|
||||
TestVectorTransferFullPartialSplitPatterns() = default;
|
||||
TestVectorTransferFullPartialSplitPatterns(
|
||||
const TestVectorTransferFullPartialSplitPatterns &pass) {}
|
||||
|
@ -361,6 +413,10 @@ struct TestVectorTransferFullPartialSplitPatterns
|
|||
|
||||
struct TestVectorTransferOpt
|
||||
: public PassWrapper<TestVectorTransferOpt, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test optimization transformations for transfer ops";
|
||||
}
|
||||
void runOnFunction() override { transferOpflowOpt(getFunction()); }
|
||||
};
|
||||
|
||||
|
@ -369,6 +425,12 @@ struct TestVectorTransferLoweringPatterns
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-transfer-lowering-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to lower transfer ops to other vector ops";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorTransferLoweringPatterns(patterns);
|
||||
|
@ -382,6 +444,13 @@ struct TestVectorMultiReductionLoweringPatterns
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
}
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-multi-reduction-lowering-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns to lower vector.multi_reduction to other "
|
||||
"vector ops";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorMultiReductionLoweringPatterns(patterns);
|
||||
|
@ -394,53 +463,27 @@ struct TestVectorMultiReductionLoweringPatterns
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestVectorConversions() {
|
||||
PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
|
||||
"test-vector-to-vector-conversion",
|
||||
"Test conversion patterns between ops in the vector dialect");
|
||||
PassRegistration<TestVectorToVectorConversion>();
|
||||
|
||||
PassRegistration<TestVectorSlicesConversion> slicesPass(
|
||||
"test-vector-slices-conversion",
|
||||
"Test conversion patterns that lower slices ops in the vector dialect");
|
||||
PassRegistration<TestVectorSlicesConversion>();
|
||||
|
||||
PassRegistration<TestVectorContractionConversion> contractionPass(
|
||||
"test-vector-contraction-conversion",
|
||||
"Test conversion patterns that lower contract ops in the vector dialect");
|
||||
PassRegistration<TestVectorContractionConversion>();
|
||||
|
||||
PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
|
||||
"test-vector-unrolling-patterns",
|
||||
"Test conversion patterns to unroll contract ops in the vector dialect");
|
||||
PassRegistration<TestVectorUnrollingPatterns>();
|
||||
|
||||
PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
|
||||
"test-vector-transfer-unrolling-patterns",
|
||||
"Test conversion patterns to unroll transfer ops in the vector dialect");
|
||||
PassRegistration<TestVectorTransferUnrollingPatterns>();
|
||||
|
||||
PassRegistration<TestVectorTransferFullPartialSplitPatterns>
|
||||
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
|
||||
"Test conversion patterns to split "
|
||||
"transfer ops via scf.if + linalg ops");
|
||||
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
|
||||
|
||||
PassRegistration<TestVectorDistributePatterns> distributePass(
|
||||
"test-vector-distribute-patterns",
|
||||
"Test conversion patterns to distribute vector ops in the vector "
|
||||
"dialect");
|
||||
PassRegistration<TestVectorDistributePatterns>();
|
||||
|
||||
PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
|
||||
"test-vector-to-forloop",
|
||||
"Test conversion patterns to break up a vector op into a for loop");
|
||||
PassRegistration<TestVectorToLoopPatterns>();
|
||||
|
||||
PassRegistration<TestVectorTransferOpt> transferOpOpt(
|
||||
"test-vector-transferop-opt",
|
||||
"Test optimization transformations for transfer ops");
|
||||
PassRegistration<TestVectorTransferOpt>();
|
||||
|
||||
PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
|
||||
"test-vector-transfer-lowering-patterns",
|
||||
"Test conversion patterns to lower transfer ops to other vector ops");
|
||||
PassRegistration<TestVectorTransferLoweringPatterns>();
|
||||
|
||||
PassRegistration<TestVectorMultiReductionLoweringPatterns>
|
||||
multiDimReductionOpLoweringPass(
|
||||
"test-vector-multi-reduction-lowering-patterns",
|
||||
"Test conversion patterns to lower vector.multi_reduction to other "
|
||||
"vector ops");
|
||||
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -91,6 +91,10 @@ private:
|
|||
};
|
||||
|
||||
struct TestDominancePass : public PassWrapper<TestDominancePass, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-print-dominance"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Print the dominance information for multiple regions.";
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
llvm::errs() << "Testing : " << getFunction().getName() << "\n";
|
||||
|
@ -120,10 +124,6 @@ struct TestDominancePass : public PassWrapper<TestDominancePass, FunctionPass> {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDominancePass() {
|
||||
PassRegistration<TestDominancePass>(
|
||||
"test-print-dominance",
|
||||
"Print the dominance information for multiple regions.");
|
||||
}
|
||||
void registerTestDominancePass() { PassRegistration<TestDominancePass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,6 +15,8 @@ namespace {
|
|||
/// This is a test pass for verifying FuncOp's eraseArgument method.
|
||||
struct TestFuncEraseArg
|
||||
: public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-func-erase-arg"; }
|
||||
StringRef getDescription() const final { return "Test erasing func args."; }
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
|
@ -39,21 +41,28 @@ struct TestFuncEraseArg
|
|||
/// This is a test pass for verifying FuncOp's eraseResult method.
|
||||
struct TestFuncEraseResult
|
||||
: public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-func-erase-result"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test erasing func results.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
SmallVector<unsigned, 4> indicesToErase;
|
||||
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
|
||||
if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
|
||||
// Push back twice to test that duplicate indices are handled
|
||||
// correctly.
|
||||
if (func.getResultAttr(resultIndex, "test.erase_this_"
|
||||
"result")) {
|
||||
// Push back twice to test
|
||||
// that duplicate indices
|
||||
// are handled correctly.
|
||||
indicesToErase.push_back(resultIndex);
|
||||
indicesToErase.push_back(resultIndex);
|
||||
}
|
||||
}
|
||||
// Reverse the order to test that unsorted index lists are handled
|
||||
// correctly.
|
||||
// Reverse the order to test
|
||||
// that unsorted index lists are
|
||||
// handled correctly.
|
||||
std::reverse(indicesToErase.begin(), indicesToErase.end());
|
||||
func.eraseResults(indicesToErase);
|
||||
}
|
||||
|
@ -63,6 +72,8 @@ struct TestFuncEraseResult
|
|||
/// This is a test pass for verifying FuncOp's setType method.
|
||||
struct TestFuncSetType
|
||||
: public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-func-set-type"; }
|
||||
StringRef getDescription() const final { return "Test FuncOp::setType."; }
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
SymbolTable symbolTable(module);
|
||||
|
@ -79,13 +90,10 @@ struct TestFuncSetType
|
|||
|
||||
namespace mlir {
|
||||
void registerTestFunc() {
|
||||
PassRegistration<TestFuncEraseArg>("test-func-erase-arg",
|
||||
"Test erasing func args.");
|
||||
PassRegistration<TestFuncEraseArg>();
|
||||
|
||||
PassRegistration<TestFuncEraseResult>("test-func-erase-result",
|
||||
"Test erasing func results.");
|
||||
PassRegistration<TestFuncEraseResult>();
|
||||
|
||||
PassRegistration<TestFuncSetType>("test-func-set-type",
|
||||
"Test FuncOp::setType.");
|
||||
PassRegistration<TestFuncSetType>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,6 +17,10 @@ namespace {
|
|||
/// application.
|
||||
struct TestTypeInterfaces
|
||||
: public PassWrapper<TestTypeInterfaces, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-type-interfaces"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test type interface support.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
getOperation().walk([](Operation *op) {
|
||||
for (Type type : op->getResultTypes()) {
|
||||
|
@ -40,9 +44,6 @@ struct TestTypeInterfaces
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestInterfaces() {
|
||||
PassRegistration<TestTypeInterfaces> pass("test-type-interfaces",
|
||||
"Test type interface support.");
|
||||
}
|
||||
void registerTestInterfaces() { PassRegistration<TestTypeInterfaces>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,6 +17,10 @@ namespace {
|
|||
/// This is a test pass for verifying matchers.
|
||||
struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
StringRef getArgument() const final { return "test-matchers"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test C++ pattern matchers.";
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -148,7 +152,5 @@ void TestMatchers::runOnFunction() {
|
|||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestMatchers() {
|
||||
PassRegistration<TestMatchers>("test-matchers", "Test C++ pattern matchers.");
|
||||
}
|
||||
void registerTestMatchers() { PassRegistration<TestMatchers>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -18,6 +18,10 @@ namespace {
|
|||
/// locations.
|
||||
struct TestOpaqueLoc
|
||||
: public PassWrapper<TestOpaqueLoc, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-opaque-loc"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Changes all leaf locations to opaque locations";
|
||||
}
|
||||
|
||||
/// A simple structure which is used for testing as an underlying location in
|
||||
/// OpaqueLoc.
|
||||
|
@ -82,9 +86,6 @@ struct TestOpaqueLoc
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestOpaqueLoc() {
|
||||
PassRegistration<TestOpaqueLoc> pass(
|
||||
"test-opaque-loc", "Changes all leaf locations to opaque locations");
|
||||
}
|
||||
void registerTestOpaqueLoc() { PassRegistration<TestOpaqueLoc>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -16,6 +16,8 @@ namespace {
|
|||
/// This pass illustrates the IR def-use chains through printing.
|
||||
struct TestPrintDefUsePass
|
||||
: public PassWrapper<TestPrintDefUsePass, OperationPass<>> {
|
||||
StringRef getArgument() const final { return "test-print-defuse"; }
|
||||
StringRef getDescription() const final { return "Test various printing."; }
|
||||
void runOnOperation() override {
|
||||
// Recursively traverse the IR nested under the current operation and print
|
||||
// every single operation and their operands and users.
|
||||
|
@ -64,8 +66,5 @@ struct TestPrintDefUsePass
|
|||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestPrintDefUsePass() {
|
||||
PassRegistration<TestPrintDefUsePass>("test-print-defuse",
|
||||
"Test various printing.");
|
||||
}
|
||||
void registerTestPrintDefUsePass() { PassRegistration<TestPrintDefUsePass>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -16,6 +16,8 @@ namespace {
|
|||
/// This pass illustrates the IR nesting through printing.
|
||||
struct TestPrintNestingPass
|
||||
: public PassWrapper<TestPrintNestingPass, OperationPass<>> {
|
||||
StringRef getArgument() const final { return "test-print-nesting"; }
|
||||
StringRef getDescription() const final { return "Test various printing."; }
|
||||
// Entry point for the pass.
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
|
@ -90,7 +92,6 @@ struct TestPrintNestingPass
|
|||
|
||||
namespace mlir {
|
||||
void registerTestPrintNestingPass() {
|
||||
PassRegistration<TestPrintNestingPass>("test-print-nesting",
|
||||
"Test various printing.");
|
||||
PassRegistration<TestPrintNestingPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -14,6 +14,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct SideEffectsPass
|
||||
: public PassWrapper<SideEffectsPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-side-effects"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test side effects interfaces";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
|
@ -68,8 +72,5 @@ struct SideEffectsPass
|
|||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerSideEffectTestPasses() {
|
||||
PassRegistration<SideEffectsPass>("test-side-effects",
|
||||
"Test side effects interfaces");
|
||||
}
|
||||
void registerSideEffectTestPasses() { PassRegistration<SideEffectsPass>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -47,6 +47,10 @@ namespace {
|
|||
/// Pass to test slice generated from slice analysis.
|
||||
struct SliceAnalysisTestPass
|
||||
: public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "slice-analysis-test"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test Slice analysis functionality.";
|
||||
}
|
||||
void runOnOperation() override;
|
||||
SliceAnalysisTestPass() = default;
|
||||
SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
|
||||
|
@ -74,7 +78,6 @@ void SliceAnalysisTestPass::runOnOperation() {
|
|||
|
||||
namespace mlir {
|
||||
void registerSliceAnalysisTestPass() {
|
||||
PassRegistration<SliceAnalysisTestPass> pass(
|
||||
"slice-analysis-test", "Test Slice analysis functionality.");
|
||||
PassRegistration<SliceAnalysisTestPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,6 +17,10 @@ namespace {
|
|||
/// provided by the symbol table along with erasing from the symbol table.
|
||||
struct SymbolUsesPass
|
||||
: public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-symbol-uses"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test detection of symbol uses";
|
||||
}
|
||||
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
|
||||
SmallVectorImpl<FuncOp> &deadFunctions) {
|
||||
// Test computing uses on a non symboltable op.
|
||||
|
@ -89,6 +93,10 @@ struct SymbolUsesPass
|
|||
/// functionality provided by the symbol table.
|
||||
struct SymbolReplacementPass
|
||||
: public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-symbol-rauw"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test replacement of symbol uses";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
|
@ -111,10 +119,8 @@ struct SymbolReplacementPass
|
|||
|
||||
namespace mlir {
|
||||
void registerSymbolTestPasses() {
|
||||
PassRegistration<SymbolUsesPass>("test-symbol-uses",
|
||||
"Test detection of symbol uses");
|
||||
PassRegistration<SymbolUsesPass>();
|
||||
|
||||
PassRegistration<SymbolReplacementPass>("test-symbol-rauw",
|
||||
"Test replacement of symbol uses");
|
||||
PassRegistration<SymbolReplacementPass>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -18,6 +18,10 @@ struct TestRecursiveTypesPass
|
|||
: public PassWrapper<TestRecursiveTypesPass, FunctionPass> {
|
||||
LogicalResult createIRWithTypes();
|
||||
|
||||
StringRef getArgument() const final { return "test-recursive-types"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test support for recursive types";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
|
@ -73,8 +77,7 @@ namespace mlir {
|
|||
namespace test {
|
||||
|
||||
void registerTestRecursiveTypesPass() {
|
||||
PassRegistration<TestRecursiveTypesPass> reg(
|
||||
"test-recursive-types", "Test support for recursive types");
|
||||
PassRegistration<TestRecursiveTypesPass>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
@ -152,6 +152,8 @@ namespace {
|
|||
/// This pass exercises the different configurations of the IR visitors.
|
||||
struct TestIRVisitorsPass
|
||||
: public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
|
||||
StringRef getArgument() const final { return "test-ir-visitors"; }
|
||||
StringRef getDescription() const final { return "Test various visitors."; }
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
testPureCallbacks(op);
|
||||
|
@ -163,9 +165,6 @@ struct TestIRVisitorsPass
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestIRVisitorsPass() {
|
||||
PassRegistration<TestIRVisitorsPass>("test-ir-visitors",
|
||||
"Test various visitors.");
|
||||
}
|
||||
void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -20,6 +20,11 @@ namespace {
|
|||
class TestDynamicPipelinePass
|
||||
: public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-dynamic-pipeline"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests the dynamic pipeline feature by applying "
|
||||
"a pipeline on a selected set of functions";
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
OpPassManager pm(ModuleOp::getOperationName(),
|
||||
OpPassManager::Nesting::Implicit);
|
||||
|
@ -106,9 +111,7 @@ public:
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDynamicPipelinePass() {
|
||||
PassRegistration<TestDynamicPipelinePass>(
|
||||
"test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
|
||||
"a pipeline on a selected set of functions");
|
||||
PassRegistration<TestDynamicPipelinePass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,10 +17,16 @@ struct TestModulePass
|
|||
: public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() final {}
|
||||
StringRef getArgument() const final { return "test-module-pass"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test a module pass in the pass manager";
|
||||
}
|
||||
};
|
||||
struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
|
||||
void runOnFunction() final {}
|
||||
StringRef getArgument() const final { return "test-function-pass"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test a function pass in the pass manager";
|
||||
}
|
||||
};
|
||||
class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
|
||||
public:
|
||||
|
@ -44,6 +50,9 @@ public:
|
|||
|
||||
void runOnFunction() final {}
|
||||
StringRef getArgument() const final { return "test-options-pass"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test options parsing capabilities";
|
||||
}
|
||||
|
||||
ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
|
||||
llvm::cl::desc("Example list option")};
|
||||
|
@ -60,12 +69,19 @@ class TestCrashRecoveryPass
|
|||
: public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
|
||||
void runOnOperation() final { abort(); }
|
||||
StringRef getArgument() const final { return "test-pass-crash"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test a pass in the pass manager that always crashes";
|
||||
}
|
||||
};
|
||||
|
||||
/// A test pass that always fails to enable testing the failure recovery
|
||||
/// mechanisms of the pass manager.
|
||||
class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
|
||||
void runOnOperation() final { signalPassFailure(); }
|
||||
StringRef getArgument() const final { return "test-pass-failure"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test a pass in the pass manager that always fails";
|
||||
}
|
||||
};
|
||||
|
||||
/// A test pass that contains a statistic.
|
||||
|
@ -73,6 +89,8 @@ struct TestStatisticPass
|
|||
: public PassWrapper<TestStatisticPass, OperationPass<>> {
|
||||
TestStatisticPass() = default;
|
||||
TestStatisticPass(const TestStatisticPass &) {}
|
||||
StringRef getArgument() const final { return "test-stats-pass"; }
|
||||
StringRef getDescription() const final { return "Test pass statistics"; }
|
||||
|
||||
Statistic opCount{this, "num-ops", "Number of operations counted"};
|
||||
|
||||
|
@ -102,22 +120,16 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
|
|||
|
||||
namespace mlir {
|
||||
void registerPassManagerTestPass() {
|
||||
PassRegistration<TestOptionsPass>("test-options-pass",
|
||||
"Test options parsing capabilities");
|
||||
PassRegistration<TestOptionsPass>();
|
||||
|
||||
PassRegistration<TestModulePass>("test-module-pass",
|
||||
"Test a module pass in the pass manager");
|
||||
PassRegistration<TestModulePass>();
|
||||
|
||||
PassRegistration<TestFunctionPass>(
|
||||
"test-function-pass", "Test a function pass in the pass manager");
|
||||
PassRegistration<TestFunctionPass>();
|
||||
|
||||
PassRegistration<TestCrashRecoveryPass>(
|
||||
"test-pass-crash", "Test a pass in the pass manager that always crashes");
|
||||
PassRegistration<TestFailurePass>(
|
||||
"test-pass-failure", "Test a pass in the pass manager that always fails");
|
||||
PassRegistration<TestCrashRecoveryPass>();
|
||||
PassRegistration<TestFailurePass>();
|
||||
|
||||
PassRegistration<TestStatisticPass> unusedStatP("test-stats-pass",
|
||||
"Test pass statistics");
|
||||
PassRegistration<TestStatisticPass>();
|
||||
|
||||
PassPipelineRegistration<>("test-pm-nested-pipeline",
|
||||
"Test a nested pipeline in the pass manager",
|
||||
|
|
|
@ -26,6 +26,10 @@ namespace {
|
|||
/// "crashOp" in the input MLIR file and crashes the mlir-opt tool if the
|
||||
/// operation is found.
|
||||
struct TestReducer : public PassWrapper<TestReducer, FunctionPass> {
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests MLIR Reduce tool by generating failures";
|
||||
}
|
||||
TestReducer() = default;
|
||||
TestReducer(const TestReducer &pass){};
|
||||
void runOnFunction() override;
|
||||
|
@ -47,8 +51,5 @@ void TestReducer::runOnFunction() {
|
|||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestReducer() {
|
||||
PassRegistration<TestReducer>(
|
||||
PASS_NAME, "Tests MLIR Reduce tool by generating failures");
|
||||
}
|
||||
void registerTestReducer() { PassRegistration<TestReducer>(); }
|
||||
} // namespace mlir
|
||||
|
|
|
@ -71,6 +71,10 @@ static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
|||
namespace {
|
||||
struct TestPDLByteCodePass
|
||||
: public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
|
||||
StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test PDL ByteCode functionality";
|
||||
}
|
||||
void runOnOperation() final {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
|
@ -107,9 +111,6 @@ struct TestPDLByteCodePass
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestPDLByteCodePass() {
|
||||
PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
|
||||
"Test PDL ByteCode functionality");
|
||||
}
|
||||
void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -16,6 +16,10 @@ using namespace mlir;
|
|||
namespace {
|
||||
/// Simple constant folding pass.
|
||||
struct TestConstantFold : public PassWrapper<TestConstantFold, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-constant-fold"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test operation constant folding";
|
||||
}
|
||||
// All constants in the function post folding.
|
||||
SmallVector<Operation *, 8> existingConstants;
|
||||
|
||||
|
@ -62,9 +66,6 @@ void TestConstantFold::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestConstantFold() {
|
||||
PassRegistration<TestConstantFold>("test-constant-fold",
|
||||
"Test operation constant folding");
|
||||
}
|
||||
void registerTestConstantFold() { PassRegistration<TestConstantFold>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -26,6 +26,11 @@ using namespace mlir::test;
|
|||
|
||||
namespace {
|
||||
struct Inliner : public PassWrapper<Inliner, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-inline"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test inlining region calls";
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
auto function = getFunction();
|
||||
|
||||
|
@ -63,8 +68,6 @@ struct Inliner : public PassWrapper<Inliner, FunctionPass> {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerInliner() {
|
||||
PassRegistration<Inliner>("test-inline", "Test inlining region calls");
|
||||
}
|
||||
void registerInliner() { PassRegistration<Inliner>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -42,6 +42,10 @@ static llvm::cl::opt<bool> clTestLoopFusionTransformation(
|
|||
namespace {
|
||||
|
||||
struct TestLoopFusion : public PassWrapper<TestLoopFusion, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-loop-fusion"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests loop fusion utility functions.";
|
||||
}
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -198,9 +202,6 @@ void TestLoopFusion::runOnFunction() {
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLoopFusion() {
|
||||
PassRegistration<TestLoopFusion>("test-loop-fusion",
|
||||
"Tests loop fusion utility functions.");
|
||||
}
|
||||
void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -26,6 +26,12 @@ namespace {
|
|||
class TestLoopMappingPass
|
||||
: public PassWrapper<TestLoopMappingPass, FunctionPass> {
|
||||
public:
|
||||
StringRef getArgument() const final {
|
||||
return "test-mapping-to-processing-elements";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "test mapping a single loop on a virtual processor grid";
|
||||
}
|
||||
explicit TestLoopMappingPass() {}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -58,10 +64,6 @@ public:
|
|||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLoopMappingPass() {
|
||||
PassRegistration<TestLoopMappingPass>(
|
||||
"test-mapping-to-processing-elements",
|
||||
"test mapping a single loop on a virtual processor grid");
|
||||
}
|
||||
void registerTestLoopMappingPass() { PassRegistration<TestLoopMappingPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -25,6 +25,14 @@ namespace {
|
|||
class SimpleParametricLoopTilingPass
|
||||
: public PassWrapper<SimpleParametricLoopTilingPass, FunctionPass> {
|
||||
public:
|
||||
StringRef getArgument() const final {
|
||||
return "test-extract-fixed-outer-loops";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "test application of parametric tiling to the outer loops so that "
|
||||
"the "
|
||||
"ranges of outer loops become static";
|
||||
}
|
||||
SimpleParametricLoopTilingPass() = default;
|
||||
SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {}
|
||||
explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) {
|
||||
|
@ -51,10 +59,7 @@ public:
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerSimpleParametricTilingPass() {
|
||||
PassRegistration<SimpleParametricLoopTilingPass>(
|
||||
"test-extract-fixed-outer-loops",
|
||||
"test application of parametric tiling to the outer loops so that the "
|
||||
"ranges of outer loops become static");
|
||||
PassRegistration<SimpleParametricLoopTilingPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -33,6 +33,10 @@ static unsigned getNestingDepth(Operation *op) {
|
|||
class TestLoopUnrollingPass
|
||||
: public PassWrapper<TestLoopUnrollingPass, FunctionPass> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-loop-unrolling"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Tests loop unrolling transformation";
|
||||
}
|
||||
TestLoopUnrollingPass() = default;
|
||||
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
|
||||
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
|
||||
|
@ -65,8 +69,7 @@ public:
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestLoopUnrollingPass() {
|
||||
PassRegistration<TestLoopUnrollingPass>(
|
||||
"test-loop-unrolling", "Tests loop unrolling transformation");
|
||||
PassRegistration<TestLoopUnrollingPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue