Add static pass registration

Add static pass registration and change mlir-opt to use it. Future work is needed to refactor the registration for PassManager usage.

Change build targets to alwayslink to enforce registration.

PiperOrigin-RevId: 220390178
This commit is contained in:
Jacques Pienaar 2018-11-06 18:34:18 -08:00 committed by jpienaar
parent 559e816f3f
commit 6f0fb22723
18 changed files with 256 additions and 119 deletions

View File

@ -18,7 +18,10 @@
#ifndef MLIR_PASS_H
#define MLIR_PASS_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
#include <functional>
namespace mlir {
class Function;
@ -81,6 +84,65 @@ public:
virtual PassResult runOnModule(Module *m) override;
};
using PassAllocatorFunction = std::function<Pass *()>;
/// Structure to group information about a pass (argument to invoke via
/// mlir-opt, description, pass allocator and unique ID).
class PassInfo {
public:
/// PassInfo constructor should not be invoked directly, instead use
/// PassRegistration or registerPass.
PassInfo(StringRef arg, StringRef description, const void *passID,
PassAllocatorFunction allocator)
: arg(arg), description(description), allocator(allocator),
passID(passID){};
/// Returns an allocated instance of this pass.
Pass *createPass() const {
assert(allocator &&
"Cannot call createPass on PassInfo without default allocator");
return allocator();
}
/// Returns the command line option that may be passed to 'mlir-opt' that will
/// cause this pass to run or null if there is no such argument.
StringRef getPassArgument() const { return arg; }
/// Returns a description for the pass, this never returns null.
StringRef getPassDescription() const { return description; }
private:
// The argument with which to invoke the pass via mlir-opt.
StringRef arg;
// Description of the pass.
StringRef description;
// Allocator to construct an instance of this pass.
PassAllocatorFunction allocator;
// Unique identifier for pass.
const void *passID;
};
/// Register a specific dialect creation function with the system, typically
/// used through the PassRegistration template.
void registerPass(StringRef arg, StringRef description, const void *passID,
const PassAllocatorFunction &function);
/// PassRegistration provides a global initializer that registers a Pass
/// allocation routine.
///
/// Usage:
///
/// // At namespace scope.
/// static PassRegistration<MyPass> Unused("unused", "Unused pass");
template <typename ConcretePass> struct PassRegistration {
PassRegistration(StringRef arg, StringRef description) {
registerPass(arg, description, &ConcretePass::passID,
[&]() { return new ConcretePass(); });
}
};
} // end namespace mlir
#endif // MLIR_PASS_H

View File

@ -0,0 +1,40 @@
//===- PassNameParser.h - Base classes for compiler passes ------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// The PassNameParser class adds all passes linked in to the system that are
// creatable to the tool.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_PASSNAMEPARSER_H_
#define MLIR_SUPPORT_PASSNAMEPARSER_H_
#include "llvm/Support/CommandLine.h"
namespace mlir {
class PassInfo;
/// Adds command line option for each registered pass.
struct PassNameParser : public llvm::cl::parser<const PassInfo *> {
PassNameParser(llvm::cl::Option &opt);
void printOptionInfo(const llvm::cl::Option &O,
size_t GlobalWidth) const override;
};
} // end namespace mlir
#endif // MLIR_SUPPORT_PASSNAMEPARSER_H_

View File

@ -45,10 +45,14 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
void visitOperationStmt(OperationStmt *opStmt);
static char passID;
};
} // end anonymous namespace
char MemRefBoundCheck::passID = 0;
FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
@ -164,3 +168,7 @@ void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) {
PassResult MemRefBoundCheck::runOnMLFunction(MLFunction *f) {
return walk(f), success();
}
static PassRegistration<MemRefBoundCheck>
memRefBoundCheck("memref-bound-check",
"Check memref accesses in an MLFunction");

View File

@ -51,10 +51,13 @@ struct MemRefDependenceCheck : public FunctionPass,
loadsAndStores.push_back(opStmt);
}
}
static char passID;
};
} // end anonymous namespace
char MemRefDependenceCheck::passID = 0;
FunctionPass *mlir::createMemRefDependenceCheckPass() {
return new MemRefDependenceCheck();
}
@ -132,3 +135,7 @@ PassResult MemRefDependenceCheck::runOnMLFunction(MLFunction *f) {
checkDependences(loadsAndStores);
return success();
}
static PassRegistration<MemRefDependenceCheck>
pass("memref-dependence-check",
"Checks dependences between all pairs of memref accesses.");

View File

@ -23,6 +23,9 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/PassNameParser.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
/// Out of line virtual method to ensure vtables and metadata are emitted to a
@ -51,3 +54,37 @@ PassResult FunctionPass::runOnFunction(Function *fn) {
return success();
}
// TODO: The pass registry and pass name parsing should be moved out.
static llvm::ManagedStatic<llvm::DenseMap<const void *, PassInfo>> passRegistry;
void mlir::registerPass(StringRef arg, StringRef description,
const void *passID,
const PassAllocatorFunction &function) {
bool inserted = passRegistry
->insert(std::make_pair(
passID, PassInfo(arg, description, passID, function)))
.second;
assert(inserted && "Pass registered multiple times");
(void)inserted;
}
PassNameParser::PassNameParser(llvm::cl::Option &opt)
: llvm::cl::parser<const PassInfo *>(opt) {
for (const auto &kv : *passRegistry) {
addLiteralOption(kv.second.getPassArgument(), &kv.second,
kv.second.getPassDescription());
}
}
void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
size_t GlobalWidth) const {
PassNameParser *TP = const_cast<PassNameParser *>(this);
llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
[](const PassNameParser::OptionInfo *VT1,
const PassNameParser::OptionInfo *VT2) {
return VT1->Name.compare(VT2->Name);
});
using llvm::cl::parser;
parser<const PassInfo *>::printOptionInfo(O, GlobalWidth);
}

View File

@ -74,13 +74,16 @@ void mlir::CFGFunction::viewGraph() const {
namespace {
struct PrintCFGPass : public FunctionPass {
PrintCFGPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title)
PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false,
const llvm::Twine &title = "")
: os(os), shortNames(shortNames), title(title) {}
PassResult runOnCFGFunction(CFGFunction *function) override {
mlir::writeGraph(os, function, shortNames, title);
return success();
}
static char passID;
private:
llvm::raw_ostream &os;
bool shortNames;
@ -88,8 +91,13 @@ private:
};
} // namespace
char PrintCFGPass::passID = 0;
FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os,
bool shortNames,
const llvm::Twine &title) {
return new PrintCFGPass(os, shortNames, title);
}
static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
"Print CFG graph per function");

View File

@ -35,9 +35,13 @@ namespace {
/// Canonicalize operations in functions.
struct Canonicalizer : public FunctionPass {
PassResult runOnFunction(Function *fn) override;
static char passID;
};
} // end anonymous namespace
char Canonicalizer::passID = 0;
PassResult Canonicalizer::runOnFunction(Function *fn) {
auto *context = fn->getContext();
OwningPatternList patterns;
@ -54,3 +58,6 @@ PassResult Canonicalizer::runOnFunction(Function *fn) {
/// Create a Canonicalizer pass.
FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); }
static PassRegistration<Canonicalizer> pass("canonicalize",
"Canonicalize operations");

View File

@ -50,10 +50,14 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
void visitOperationStmt(OperationStmt *stmt);
PassResult runOnMLFunction(MLFunction *f) override;
using StmtWalker<ComposeAffineMaps>::walk;
static char passID;
};
} // end anonymous namespace
char ComposeAffineMaps::passID = 0;
FunctionPass *mlir::createComposeAffineMapsPass() {
return new ComposeAffineMaps();
}
@ -92,3 +96,6 @@ PassResult ComposeAffineMaps::runOnMLFunction(MLFunction *f) {
}
return success();
}
static PassRegistration<ComposeAffineMaps> pass("compose-affine-maps",
"Compose affine maps");

View File

@ -40,9 +40,13 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
void visitForStmt(ForStmt *stmt);
PassResult runOnCFGFunction(CFGFunction *f) override;
PassResult runOnMLFunction(MLFunction *f) override;
static char passID;
};
} // end anonymous namespace
char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
@ -174,3 +178,6 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
/// Creates a constant folding pass.
FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); }
static PassRegistration<ConstantFold>
pass("constant-fold", "Constant fold operations in functions");

View File

@ -70,6 +70,8 @@ public:
PassResult runOnModule(Module *m) override;
static char passID;
private:
// Generates CFG functions for all ML functions in the module.
void convertMLFunctions();
@ -90,6 +92,8 @@ private:
};
} // end anonymous namespace
char ModuleConverter::passID = 0;
// Iterates over all functions in the module generating CFG functions
// equivalent to ML functions and replacing references to ML functions
// with references to the generated ML functions.
@ -163,3 +167,7 @@ void ModuleConverter::removeMLFunctions() {
/// Function references are appropriately patched to refer to the newly
/// generated CFG functions.
ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }
static PassRegistration<ModuleConverter>
pass("convert-to-cfg",
"Convert all ML functions in the module to CFG ones");

View File

@ -45,6 +45,7 @@ struct LoopFusion : public FunctionPass {
LoopFusion() {}
PassResult runOnMLFunction(MLFunction *f) override;
static char passID;
};
// LoopCollector walks the statements in an MLFunction and builds a map from
@ -75,6 +76,8 @@ public:
} // end anonymous namespace
char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
// TODO(andydavis) Remove the following test code when more general loop
@ -242,3 +245,5 @@ PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
return success();
}
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");

View File

@ -42,10 +42,14 @@ namespace {
struct LoopTiling : public FunctionPass {
PassResult runOnMLFunction(MLFunction *f) override;
constexpr static unsigned kDefaultTileSize = 32;
static char passID;
};
} // end anonymous namespace
char LoopTiling::passID = 0;
/// Creates a pass to perform loop tiling on all suitable loop nests of an
/// MLFunction.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
@ -238,3 +242,5 @@ PassResult LoopTiling::runOnMLFunction(MLFunction *f) {
}
return success();
}
static PassRegistration<LoopTiling> pass("loop-tile", "Tile loop nests");

View File

@ -56,22 +56,20 @@ struct LoopUnroll : public FunctionPass {
Optional<unsigned> unrollFactor;
Optional<bool> unrollFull;
explicit LoopUnroll(Optional<unsigned> unrollFactor,
Optional<bool> unrollFull)
explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
Optional<bool> unrollFull = None)
: unrollFactor(unrollFactor), unrollFull(unrollFull) {}
PassResult runOnMLFunction(MLFunction *f) override;
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
static char passID;
};
} // end anonymous namespace
FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
return new LoopUnroll(unrollFactor == -1 ? None
: Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull));
}
char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
@ -286,3 +284,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
return true;
}
FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
return new LoopUnroll(unrollFactor == -1 ? None
: Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull));
}
static PassRegistration<LoopUnroll> pass("loop-unroll", "Unroll loops");

View File

@ -70,14 +70,18 @@ struct LoopUnrollAndJam : public FunctionPass {
Optional<unsigned> unrollJamFactor;
static const unsigned kDefaultUnrollJamFactor = 4;
explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor)
explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
: unrollJamFactor(unrollJamFactor) {}
PassResult runOnMLFunction(MLFunction *f) override;
bool runOnForStmt(ForStmt *forStmt);
static char passID;
};
} // end anonymous namespace
char LoopUnrollAndJam::passID = 0;
FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
return new LoopUnrollAndJam(
unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
@ -239,3 +243,6 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
return true;
}
static PassRegistration<LoopUnrollAndJam> pass("loop-unroll-jam",
"Unroll and jam loops");

View File

@ -47,10 +47,14 @@ struct PipelineDataTransfer : public FunctionPass,
// Collect all 'for' statements.
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
std::vector<ForStmt *> forStmts;
static char passID;
};
} // end anonymous namespace
char PipelineDataTransfer::passID = 0;
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
FunctionPass *mlir::createPipelineDataTransferPass() {
@ -306,3 +310,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
return success();
}
static PassRegistration<PipelineDataTransfer> pass(
"pipeline-data-transfer",
"Pipeline non-blocking data transfers between explicitly managed levels of "
"the memory hierarchy");

View File

@ -47,10 +47,14 @@ struct SimplifyAffineStructures : public FunctionPass,
void visitIfStmt(IfStmt *ifStmt);
void visitOperationStmt(OperationStmt *opStmt);
static char passID;
};
} // end anonymous namespace
char SimplifyAffineStructures::passID = 0;
FunctionPass *mlir::createSimplifyAffineStructuresPass() {
return new SimplifyAffineStructures();
}
@ -83,3 +87,6 @@ PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) {
walk(f);
return success();
}
static PassRegistration<SimplifyAffineStructures>
pass("simplify-affine-structures", "Simplify affine expressions");

View File

@ -199,10 +199,14 @@ struct Vectorize : public FunctionPass {
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
MLFunctionMatcherContext MLContext;
static char passID;
};
} // end anonymous namespace
char Vectorize::passID = 0;
/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. //////
namespace {
@ -669,3 +673,7 @@ PassResult Vectorize::runOnMLFunction(MLFunction *f) {
}
FunctionPass *mlir::createVectorizePass() { return new Vectorize(); }
static PassRegistration<Vectorize>
pass("vectorize",
"Vectorize to a target independent n-D vector abstraction");

View File

@ -30,6 +30,7 @@
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass.h"
#include "mlir/Support/PassNameParser.h"
#include "mlir/TensorFlow/ControlFlowOps.h"
#include "mlir/TensorFlow/Passes.h"
#include "mlir/TensorFlowLite/Passes.h"
@ -67,58 +68,7 @@ static cl::opt<bool>
"expected-* lines on the corresponding line"),
cl::init(false));
enum Passes {
Canonicalize,
ComposeAffineMaps,
ConstantFold,
ConvertToCFG,
TFLiteLegaize,
LoopFusion,
LoopTiling,
LoopUnroll,
LoopUnrollAndJam,
MemRefBoundCheck,
MemRefDependenceCheck,
PipelineDataTransfer,
PrintCFGGraph,
SimplifyAffineStructures,
TFRaiseControlFlow,
Vectorize,
XLALower,
};
static cl::list<Passes> passList(
"", cl::desc("Compiler passes to run"),
cl::values(
clEnumValN(Canonicalize, "canonicalize", "Canonicalize operations"),
clEnumValN(ComposeAffineMaps, "compose-affine-maps",
"Compose affine maps"),
clEnumValN(ConstantFold, "constant-fold",
"Constant fold operations in functions"),
clEnumValN(ConvertToCFG, "convert-to-cfg",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
clEnumValN(LoopTiling, "loop-tile", "Tile loop nests"),
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
clEnumValN(MemRefBoundCheck, "memref-bound-check",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(MemRefDependenceCheck, "memref-dependence-check",
"Checks dependences between all pairs of memref accesses."),
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
"Pipeline non-blocking data transfers between"
"explicitly managed levels of the memory hierarchy"),
clEnumValN(PrintCFGGraph, "print-cfg-graph",
"Print CFG graph per function"),
clEnumValN(SimplifyAffineStructures, "simplify-affine-structures",
"Simplify affine expressions"),
clEnumValN(TFLiteLegaize, "tfl-legalize",
"Legalize operations to TensorFlow Lite dialect"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG"),
clEnumValN(Vectorize, "vectorize",
"Vectorize to a target independent n-D vector abstraction."),
clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect")));
static std::vector<const mlir::PassInfo *> *passList;
enum OptResult { OptSuccess, OptFailure };
@ -190,65 +140,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
return OptFailure;
// Run each of the passes that were selected.
for (unsigned i = 0, e = passList.size(); i != e; ++i) {
auto passKind = passList[i];
Pass *pass = nullptr;
switch (passKind) {
case Canonicalize:
pass = createCanonicalizerPass();
break;
case ComposeAffineMaps:
pass = createComposeAffineMapsPass();
break;
case ConstantFold:
pass = createConstantFoldPass();
break;
case ConvertToCFG:
pass = createConvertToCFGPass();
break;
case LoopFusion:
pass = createLoopFusionPass();
break;
case LoopTiling:
pass = createLoopTilingPass();
break;
case LoopUnroll:
pass = createLoopUnrollPass();
break;
case LoopUnrollAndJam:
pass = createLoopUnrollAndJamPass();
break;
case MemRefBoundCheck:
pass = createMemRefBoundCheckPass();
break;
case MemRefDependenceCheck:
pass = createMemRefDependenceCheckPass();
break;
case PipelineDataTransfer:
pass = createPipelineDataTransferPass();
break;
case PrintCFGGraph:
pass = createPrintCFGGraphPass();
break;
case SimplifyAffineStructures:
pass = createSimplifyAffineStructuresPass();
break;
case TFLiteLegaize:
pass = tfl::createLegalizer();
break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();
break;
case Vectorize:
pass = createVectorizePass();
break;
case XLALower:
pass = createXLALowerPass();
break;
}
for (const auto *passInfo : *passList) {
std::unique_ptr<Pass> pass(passInfo->createPass());
PassResult result = pass->runOnModule(module.get());
delete pass;
if (result)
return OptFailure;
@ -468,6 +362,10 @@ int main(int argc, char **argv) {
llvm::PrettyStackTraceProgram x(argc, argv);
InitLLVM y(argc, argv);
// Parse pass names in main to ensure static initialization completed.
llvm::cl::list<const mlir::PassInfo *, bool, mlir::PassNameParser> passList(
"", llvm::cl::desc("Compiler passes to run"));
::passList = &passList;
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
// Set up the input file.