llvm-project/mlir/tools/mlir-tblgen/PassGen.cpp

122 lines
4.7 KiB
C++

//===- Pass.cpp - MLIR pass registration generator ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PassGen uses the description of passes to generate base classes for passes
// and command line registration.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate the start of a pass base class.
///
/// {0}: The def name of the pass record.
/// {1}: The command line argument for the pass.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
#ifdef GEN_PASS_{0}
/// Returns the command-line argument attached to this pass.
static StringRef getPassArgument() { return "{1}"; }
)";
/// The code snippet used to generate the end of a pass base class.
///
/// {0}: The def name of the pass record.
const char *const passDeclEnd = R"(
#undef GEN_PASS_{0}
#endif // GEN_PASS_{0}
)";
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
os.indent(2) << "Pass::" << (opt.isListOption() ? "ListOption" : "Option");
os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", llvm::cl::desc(\"{3}\")",
opt.getType(), opt.getCppVariableName(),
opt.getArgument(), opt.getDescription());
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
os << ", llvm::cl::init(" << defaultVal << ")";
if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
os << ", " << *additionalFlags;
os << "};\n";
}
}
/// Emit the declarations for each of the pass statistics.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) {
os << llvm::formatv(" Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
stat.getCppVariableName(), stat.getName(),
stat.getDescription());
}
}
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
os << llvm::formatv(passDeclBegin, defName, pass.getArgument());
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << llvm::formatv(passDeclEnd, defName);
}
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//
/// Emit the code for registering each of the given passes with the global
/// PassRegistry.
static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
"std::unique_ptr<Pass> {{ return {2}; });\n",
pass.getArgument(), pass.getSummary(),
pass.getConstructor());
}
os << "#undef GEN_PASS_REGISTRATION\n";
os << "#endif // GEN_PASS_REGISTRATION\n";
}
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
std::vector<Pass> passes;
for (const auto *def : recordKeeper.getAllDerivedDefinitions("Pass")) {
Pass pass(def);
passes.push_back(pass);
emitPassDecl(pass, os);
}
emitRegistration(passes, os);
}
static mlir::GenRegistration
genRegister("gen-pass-decls", "Generate operation documentation",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitDecls(records, os);
return false;
});