forked from OSchip/llvm-project
Add static pass registration
Add static pass registration and change mlir-opt to use it. Future work is needed to refactor the registration for PassManager usage. Change build targets to alwayslink to enforce registration. PiperOrigin-RevId: 220390178
This commit is contained in:
parent
559e816f3f
commit
6f0fb22723
|
@ -18,7 +18,10 @@
|
|||
#ifndef MLIR_PASS_H
|
||||
#define MLIR_PASS_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
#include <functional>
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
|
@ -81,6 +84,65 @@ public:
|
|||
virtual PassResult runOnModule(Module *m) override;
|
||||
};
|
||||
|
||||
using PassAllocatorFunction = std::function<Pass *()>;
|
||||
|
||||
/// Structure to group information about a pass (argument to invoke via
|
||||
/// mlir-opt, description, pass allocator and unique ID).
|
||||
class PassInfo {
|
||||
public:
|
||||
/// PassInfo constructor should not be invoked directly, instead use
|
||||
/// PassRegistration or registerPass.
|
||||
PassInfo(StringRef arg, StringRef description, const void *passID,
|
||||
PassAllocatorFunction allocator)
|
||||
: arg(arg), description(description), allocator(allocator),
|
||||
passID(passID){};
|
||||
|
||||
/// Returns an allocated instance of this pass.
|
||||
Pass *createPass() const {
|
||||
assert(allocator &&
|
||||
"Cannot call createPass on PassInfo without default allocator");
|
||||
return allocator();
|
||||
}
|
||||
|
||||
/// Returns the command line option that may be passed to 'mlir-opt' that will
|
||||
/// cause this pass to run or null if there is no such argument.
|
||||
StringRef getPassArgument() const { return arg; }
|
||||
|
||||
/// Returns a description for the pass, this never returns null.
|
||||
StringRef getPassDescription() const { return description; }
|
||||
|
||||
private:
|
||||
// The argument with which to invoke the pass via mlir-opt.
|
||||
StringRef arg;
|
||||
|
||||
// Description of the pass.
|
||||
StringRef description;
|
||||
|
||||
// Allocator to construct an instance of this pass.
|
||||
PassAllocatorFunction allocator;
|
||||
|
||||
// Unique identifier for pass.
|
||||
const void *passID;
|
||||
};
|
||||
|
||||
/// Register a specific dialect creation function with the system, typically
|
||||
/// used through the PassRegistration template.
|
||||
void registerPass(StringRef arg, StringRef description, const void *passID,
|
||||
const PassAllocatorFunction &function);
|
||||
|
||||
/// PassRegistration provides a global initializer that registers a Pass
|
||||
/// allocation routine.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
/// // At namespace scope.
|
||||
/// static PassRegistration<MyPass> Unused("unused", "Unused pass");
|
||||
template <typename ConcretePass> struct PassRegistration {
|
||||
PassRegistration(StringRef arg, StringRef description) {
|
||||
registerPass(arg, description, &ConcretePass::passID,
|
||||
[&]() { return new ConcretePass(); });
|
||||
}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_PASS_H
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
//===- PassNameParser.h - Base classes for compiler passes ------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// The PassNameParser class adds all passes linked in to the system that are
|
||||
// creatable to the tool.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_SUPPORT_PASSNAMEPARSER_H_
|
||||
#define MLIR_SUPPORT_PASSNAMEPARSER_H_
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
namespace mlir {
|
||||
class PassInfo;
|
||||
|
||||
/// Adds command line option for each registered pass.
|
||||
struct PassNameParser : public llvm::cl::parser<const PassInfo *> {
|
||||
PassNameParser(llvm::cl::Option &opt);
|
||||
|
||||
void printOptionInfo(const llvm::cl::Option &O,
|
||||
size_t GlobalWidth) const override;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_SUPPORT_PASSNAMEPARSER_H_
|
|
@ -45,10 +45,14 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
|
|||
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
|
||||
|
||||
void visitOperationStmt(OperationStmt *opStmt);
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char MemRefBoundCheck::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createMemRefBoundCheckPass() {
|
||||
return new MemRefBoundCheck();
|
||||
}
|
||||
|
@ -164,3 +168,7 @@ void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) {
|
|||
PassResult MemRefBoundCheck::runOnMLFunction(MLFunction *f) {
|
||||
return walk(f), success();
|
||||
}
|
||||
|
||||
static PassRegistration<MemRefBoundCheck>
|
||||
memRefBoundCheck("memref-bound-check",
|
||||
"Check memref accesses in an MLFunction");
|
||||
|
|
|
@ -51,10 +51,13 @@ struct MemRefDependenceCheck : public FunctionPass,
|
|||
loadsAndStores.push_back(opStmt);
|
||||
}
|
||||
}
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char MemRefDependenceCheck::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createMemRefDependenceCheckPass() {
|
||||
return new MemRefDependenceCheck();
|
||||
}
|
||||
|
@ -132,3 +135,7 @@ PassResult MemRefDependenceCheck::runOnMLFunction(MLFunction *f) {
|
|||
checkDependences(loadsAndStores);
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<MemRefDependenceCheck>
|
||||
pass("memref-dependence-check",
|
||||
"Checks dependences between all pairs of memref accesses.");
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Support/PassNameParser.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
using namespace mlir;
|
||||
|
||||
/// Out of line virtual method to ensure vtables and metadata are emitted to a
|
||||
|
@ -51,3 +54,37 @@ PassResult FunctionPass::runOnFunction(Function *fn) {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: The pass registry and pass name parsing should be moved out.
|
||||
static llvm::ManagedStatic<llvm::DenseMap<const void *, PassInfo>> passRegistry;
|
||||
|
||||
void mlir::registerPass(StringRef arg, StringRef description,
|
||||
const void *passID,
|
||||
const PassAllocatorFunction &function) {
|
||||
bool inserted = passRegistry
|
||||
->insert(std::make_pair(
|
||||
passID, PassInfo(arg, description, passID, function)))
|
||||
.second;
|
||||
assert(inserted && "Pass registered multiple times");
|
||||
(void)inserted;
|
||||
}
|
||||
|
||||
PassNameParser::PassNameParser(llvm::cl::Option &opt)
|
||||
: llvm::cl::parser<const PassInfo *>(opt) {
|
||||
for (const auto &kv : *passRegistry) {
|
||||
addLiteralOption(kv.second.getPassArgument(), &kv.second,
|
||||
kv.second.getPassDescription());
|
||||
}
|
||||
}
|
||||
|
||||
void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
|
||||
size_t GlobalWidth) const {
|
||||
PassNameParser *TP = const_cast<PassNameParser *>(this);
|
||||
llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
|
||||
[](const PassNameParser::OptionInfo *VT1,
|
||||
const PassNameParser::OptionInfo *VT2) {
|
||||
return VT1->Name.compare(VT2->Name);
|
||||
});
|
||||
using llvm::cl::parser;
|
||||
parser<const PassInfo *>::printOptionInfo(O, GlobalWidth);
|
||||
}
|
||||
|
|
|
@ -74,13 +74,16 @@ void mlir::CFGFunction::viewGraph() const {
|
|||
|
||||
namespace {
|
||||
struct PrintCFGPass : public FunctionPass {
|
||||
PrintCFGPass(llvm::raw_ostream &os, bool shortNames, const llvm::Twine &title)
|
||||
PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false,
|
||||
const llvm::Twine &title = "")
|
||||
: os(os), shortNames(shortNames), title(title) {}
|
||||
PassResult runOnCFGFunction(CFGFunction *function) override {
|
||||
mlir::writeGraph(os, function, shortNames, title);
|
||||
return success();
|
||||
}
|
||||
|
||||
static char passID;
|
||||
|
||||
private:
|
||||
llvm::raw_ostream &os;
|
||||
bool shortNames;
|
||||
|
@ -88,8 +91,13 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
char PrintCFGPass::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os,
|
||||
bool shortNames,
|
||||
const llvm::Twine &title) {
|
||||
return new PrintCFGPass(os, shortNames, title);
|
||||
}
|
||||
|
||||
static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
|
||||
"Print CFG graph per function");
|
||||
|
|
|
@ -35,9 +35,13 @@ namespace {
|
|||
/// Canonicalize operations in functions.
|
||||
struct Canonicalizer : public FunctionPass {
|
||||
PassResult runOnFunction(Function *fn) override;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
char Canonicalizer::passID = 0;
|
||||
|
||||
PassResult Canonicalizer::runOnFunction(Function *fn) {
|
||||
auto *context = fn->getContext();
|
||||
OwningPatternList patterns;
|
||||
|
@ -54,3 +58,6 @@ PassResult Canonicalizer::runOnFunction(Function *fn) {
|
|||
|
||||
/// Create a Canonicalizer pass.
|
||||
FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); }
|
||||
|
||||
static PassRegistration<Canonicalizer> pass("canonicalize",
|
||||
"Canonicalize operations");
|
||||
|
|
|
@ -50,10 +50,14 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
|
|||
void visitOperationStmt(OperationStmt *stmt);
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
using StmtWalker<ComposeAffineMaps>::walk;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char ComposeAffineMaps::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createComposeAffineMapsPass() {
|
||||
return new ComposeAffineMaps();
|
||||
}
|
||||
|
@ -92,3 +96,6 @@ PassResult ComposeAffineMaps::runOnMLFunction(MLFunction *f) {
|
|||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<ComposeAffineMaps> pass("compose-affine-maps",
|
||||
"Compose affine maps");
|
||||
|
|
|
@ -40,9 +40,13 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
|||
void visitForStmt(ForStmt *stmt);
|
||||
PassResult runOnCFGFunction(CFGFunction *f) override;
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
char ConstantFold::passID = 0;
|
||||
|
||||
/// Attempt to fold the specified operation, updating the IR to match. If
|
||||
/// constants are found, we keep track of them in the existingConstants list.
|
||||
///
|
||||
|
@ -174,3 +178,6 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
|
|||
|
||||
/// Creates a constant folding pass.
|
||||
FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); }
|
||||
|
||||
static PassRegistration<ConstantFold>
|
||||
pass("constant-fold", "Constant fold operations in functions");
|
||||
|
|
|
@ -70,6 +70,8 @@ public:
|
|||
|
||||
PassResult runOnModule(Module *m) override;
|
||||
|
||||
static char passID;
|
||||
|
||||
private:
|
||||
// Generates CFG functions for all ML functions in the module.
|
||||
void convertMLFunctions();
|
||||
|
@ -90,6 +92,8 @@ private:
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
char ModuleConverter::passID = 0;
|
||||
|
||||
// Iterates over all functions in the module generating CFG functions
|
||||
// equivalent to ML functions and replacing references to ML functions
|
||||
// with references to the generated ML functions.
|
||||
|
@ -163,3 +167,7 @@ void ModuleConverter::removeMLFunctions() {
|
|||
/// Function references are appropriately patched to refer to the newly
|
||||
/// generated CFG functions.
|
||||
ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }
|
||||
|
||||
static PassRegistration<ModuleConverter>
|
||||
pass("convert-to-cfg",
|
||||
"Convert all ML functions in the module to CFG ones");
|
||||
|
|
|
@ -45,6 +45,7 @@ struct LoopFusion : public FunctionPass {
|
|||
LoopFusion() {}
|
||||
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
static char passID;
|
||||
};
|
||||
|
||||
// LoopCollector walks the statements in an MLFunction and builds a map from
|
||||
|
@ -75,6 +76,8 @@ public:
|
|||
|
||||
} // end anonymous namespace
|
||||
|
||||
char LoopFusion::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
|
||||
|
||||
// TODO(andydavis) Remove the following test code when more general loop
|
||||
|
@ -242,3 +245,5 @@ PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
|
||||
|
|
|
@ -42,10 +42,14 @@ namespace {
|
|||
struct LoopTiling : public FunctionPass {
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
constexpr static unsigned kDefaultTileSize = 32;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char LoopTiling::passID = 0;
|
||||
|
||||
/// Creates a pass to perform loop tiling on all suitable loop nests of an
|
||||
/// MLFunction.
|
||||
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
|
||||
|
@ -238,3 +242,5 @@ PassResult LoopTiling::runOnMLFunction(MLFunction *f) {
|
|||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<LoopTiling> pass("loop-tile", "Tile loop nests");
|
||||
|
|
|
@ -56,22 +56,20 @@ struct LoopUnroll : public FunctionPass {
|
|||
Optional<unsigned> unrollFactor;
|
||||
Optional<bool> unrollFull;
|
||||
|
||||
explicit LoopUnroll(Optional<unsigned> unrollFactor,
|
||||
Optional<bool> unrollFull)
|
||||
explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
|
||||
Optional<bool> unrollFull = None)
|
||||
: unrollFactor(unrollFactor), unrollFull(unrollFull) {}
|
||||
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
|
||||
/// Unroll this for stmt. Returns false if nothing was done.
|
||||
bool runOnForStmt(ForStmt *forStmt);
|
||||
|
||||
static char passID;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
|
||||
return new LoopUnroll(unrollFactor == -1 ? None
|
||||
: Optional<unsigned>(unrollFactor),
|
||||
unrollFull == -1 ? None : Optional<bool>(unrollFull));
|
||||
}
|
||||
char LoopUnroll::passID = 0;
|
||||
|
||||
PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
|
||||
// Gathers all innermost loops through a post order pruned walk.
|
||||
|
@ -286,3 +284,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
|
||||
return new LoopUnroll(unrollFactor == -1 ? None
|
||||
: Optional<unsigned>(unrollFactor),
|
||||
unrollFull == -1 ? None : Optional<bool>(unrollFull));
|
||||
}
|
||||
|
||||
static PassRegistration<LoopUnroll> pass("loop-unroll", "Unroll loops");
|
||||
|
|
|
@ -70,14 +70,18 @@ struct LoopUnrollAndJam : public FunctionPass {
|
|||
Optional<unsigned> unrollJamFactor;
|
||||
static const unsigned kDefaultUnrollJamFactor = 4;
|
||||
|
||||
explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor)
|
||||
explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
|
||||
: unrollJamFactor(unrollJamFactor) {}
|
||||
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
bool runOnForStmt(ForStmt *forStmt);
|
||||
|
||||
static char passID;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
char LoopUnrollAndJam::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
|
||||
return new LoopUnrollAndJam(
|
||||
unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
|
||||
|
@ -239,3 +243,6 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
static PassRegistration<LoopUnrollAndJam> pass("loop-unroll-jam",
|
||||
"Unroll and jam loops");
|
||||
|
|
|
@ -47,10 +47,14 @@ struct PipelineDataTransfer : public FunctionPass,
|
|||
// Collect all 'for' statements.
|
||||
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
|
||||
std::vector<ForStmt *> forStmts;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char PipelineDataTransfer::passID = 0;
|
||||
|
||||
/// Creates a pass to pipeline explicit movement of data across levels of the
|
||||
/// memory hierarchy.
|
||||
FunctionPass *mlir::createPipelineDataTransferPass() {
|
||||
|
@ -306,3 +310,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<PipelineDataTransfer> pass(
|
||||
"pipeline-data-transfer",
|
||||
"Pipeline non-blocking data transfers between explicitly managed levels of "
|
||||
"the memory hierarchy");
|
||||
|
|
|
@ -47,10 +47,14 @@ struct SimplifyAffineStructures : public FunctionPass,
|
|||
|
||||
void visitIfStmt(IfStmt *ifStmt);
|
||||
void visitOperationStmt(OperationStmt *opStmt);
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char SimplifyAffineStructures::passID = 0;
|
||||
|
||||
FunctionPass *mlir::createSimplifyAffineStructuresPass() {
|
||||
return new SimplifyAffineStructures();
|
||||
}
|
||||
|
@ -83,3 +87,6 @@ PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) {
|
|||
walk(f);
|
||||
return success();
|
||||
}
|
||||
|
||||
static PassRegistration<SimplifyAffineStructures>
|
||||
pass("simplify-affine-structures", "Simplify affine expressions");
|
||||
|
|
|
@ -199,10 +199,14 @@ struct Vectorize : public FunctionPass {
|
|||
|
||||
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
||||
MLFunctionMatcherContext MLContext;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
char Vectorize::passID = 0;
|
||||
|
||||
/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. //////
|
||||
namespace {
|
||||
|
||||
|
@ -669,3 +673,7 @@ PassResult Vectorize::runOnMLFunction(MLFunction *f) {
|
|||
}
|
||||
|
||||
FunctionPass *mlir::createVectorizePass() { return new Vectorize(); }
|
||||
|
||||
static PassRegistration<Vectorize>
|
||||
pass("vectorize",
|
||||
"Vectorize to a target independent n-D vector abstraction");
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/Support/PassNameParser.h"
|
||||
#include "mlir/TensorFlow/ControlFlowOps.h"
|
||||
#include "mlir/TensorFlow/Passes.h"
|
||||
#include "mlir/TensorFlowLite/Passes.h"
|
||||
|
@ -67,58 +68,7 @@ static cl::opt<bool>
|
|||
"expected-* lines on the corresponding line"),
|
||||
cl::init(false));
|
||||
|
||||
enum Passes {
|
||||
Canonicalize,
|
||||
ComposeAffineMaps,
|
||||
ConstantFold,
|
||||
ConvertToCFG,
|
||||
TFLiteLegaize,
|
||||
LoopFusion,
|
||||
LoopTiling,
|
||||
LoopUnroll,
|
||||
LoopUnrollAndJam,
|
||||
MemRefBoundCheck,
|
||||
MemRefDependenceCheck,
|
||||
PipelineDataTransfer,
|
||||
PrintCFGGraph,
|
||||
SimplifyAffineStructures,
|
||||
TFRaiseControlFlow,
|
||||
Vectorize,
|
||||
XLALower,
|
||||
};
|
||||
|
||||
static cl::list<Passes> passList(
|
||||
"", cl::desc("Compiler passes to run"),
|
||||
cl::values(
|
||||
clEnumValN(Canonicalize, "canonicalize", "Canonicalize operations"),
|
||||
clEnumValN(ComposeAffineMaps, "compose-affine-maps",
|
||||
"Compose affine maps"),
|
||||
clEnumValN(ConstantFold, "constant-fold",
|
||||
"Constant fold operations in functions"),
|
||||
clEnumValN(ConvertToCFG, "convert-to-cfg",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
|
||||
clEnumValN(LoopTiling, "loop-tile", "Tile loop nests"),
|
||||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
|
||||
clEnumValN(MemRefBoundCheck, "memref-bound-check",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(MemRefDependenceCheck, "memref-dependence-check",
|
||||
"Checks dependences between all pairs of memref accesses."),
|
||||
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
|
||||
"Pipeline non-blocking data transfers between"
|
||||
"explicitly managed levels of the memory hierarchy"),
|
||||
clEnumValN(PrintCFGGraph, "print-cfg-graph",
|
||||
"Print CFG graph per function"),
|
||||
clEnumValN(SimplifyAffineStructures, "simplify-affine-structures",
|
||||
"Simplify affine expressions"),
|
||||
clEnumValN(TFLiteLegaize, "tfl-legalize",
|
||||
"Legalize operations to TensorFlow Lite dialect"),
|
||||
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
|
||||
"Dynamic TensorFlow Switch/Match nodes to a CFG"),
|
||||
clEnumValN(Vectorize, "vectorize",
|
||||
"Vectorize to a target independent n-D vector abstraction."),
|
||||
clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect")));
|
||||
static std::vector<const mlir::PassInfo *> *passList;
|
||||
|
||||
enum OptResult { OptSuccess, OptFailure };
|
||||
|
||||
|
@ -190,65 +140,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
|
|||
return OptFailure;
|
||||
|
||||
// Run each of the passes that were selected.
|
||||
for (unsigned i = 0, e = passList.size(); i != e; ++i) {
|
||||
auto passKind = passList[i];
|
||||
Pass *pass = nullptr;
|
||||
switch (passKind) {
|
||||
case Canonicalize:
|
||||
pass = createCanonicalizerPass();
|
||||
break;
|
||||
case ComposeAffineMaps:
|
||||
pass = createComposeAffineMapsPass();
|
||||
break;
|
||||
case ConstantFold:
|
||||
pass = createConstantFoldPass();
|
||||
break;
|
||||
case ConvertToCFG:
|
||||
pass = createConvertToCFGPass();
|
||||
break;
|
||||
case LoopFusion:
|
||||
pass = createLoopFusionPass();
|
||||
break;
|
||||
case LoopTiling:
|
||||
pass = createLoopTilingPass();
|
||||
break;
|
||||
case LoopUnroll:
|
||||
pass = createLoopUnrollPass();
|
||||
break;
|
||||
case LoopUnrollAndJam:
|
||||
pass = createLoopUnrollAndJamPass();
|
||||
break;
|
||||
case MemRefBoundCheck:
|
||||
pass = createMemRefBoundCheckPass();
|
||||
break;
|
||||
case MemRefDependenceCheck:
|
||||
pass = createMemRefDependenceCheckPass();
|
||||
break;
|
||||
case PipelineDataTransfer:
|
||||
pass = createPipelineDataTransferPass();
|
||||
break;
|
||||
case PrintCFGGraph:
|
||||
pass = createPrintCFGGraphPass();
|
||||
break;
|
||||
case SimplifyAffineStructures:
|
||||
pass = createSimplifyAffineStructuresPass();
|
||||
break;
|
||||
case TFLiteLegaize:
|
||||
pass = tfl::createLegalizer();
|
||||
break;
|
||||
case TFRaiseControlFlow:
|
||||
pass = createRaiseTFControlFlowPass();
|
||||
break;
|
||||
case Vectorize:
|
||||
pass = createVectorizePass();
|
||||
break;
|
||||
case XLALower:
|
||||
pass = createXLALowerPass();
|
||||
break;
|
||||
}
|
||||
|
||||
for (const auto *passInfo : *passList) {
|
||||
std::unique_ptr<Pass> pass(passInfo->createPass());
|
||||
PassResult result = pass->runOnModule(module.get());
|
||||
delete pass;
|
||||
if (result)
|
||||
return OptFailure;
|
||||
|
||||
|
@ -468,6 +362,10 @@ int main(int argc, char **argv) {
|
|||
llvm::PrettyStackTraceProgram x(argc, argv);
|
||||
InitLLVM y(argc, argv);
|
||||
|
||||
// Parse pass names in main to ensure static initialization completed.
|
||||
llvm::cl::list<const mlir::PassInfo *, bool, mlir::PassNameParser> passList(
|
||||
"", llvm::cl::desc("Compiler passes to run"));
|
||||
::passList = &passList;
|
||||
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
|
||||
|
||||
// Set up the input file.
|
||||
|
|
Loading…
Reference in New Issue