Switch from positional argument to explicit flags for mlir-translate

This results in uniform behavior with mlir-opt. Exactly one transformation is allowed.

PiperOrigin-RevId: 213493415
This commit is contained in:
Jacques Pienaar 2018-09-18 12:18:24 -07:00 committed by jpienaar
parent ab4797229c
commit 81a066e6e7
3 changed files with 44 additions and 46 deletions

View File

@ -1,4 +1,4 @@
// RUN: mlir-translate mlir-to-mlir %s -o - | FileCheck %s
// RUN: mlir-translate -mlir-to-mlir %s -o - | FileCheck %s
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4)
#map0 = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4)

View File

@ -35,10 +35,6 @@
using namespace mlir;
static llvm::cl::opt<std::string>
translationRequested(llvm::cl::Positional,
llvm::cl::desc("<translation-requested>"));
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
@ -47,25 +43,6 @@ static llvm::cl::opt<std::string>
outputFilename("o", llvm::cl::desc("Output filename"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
// Static map between translations registered and the TranslateFunctions that
// perform those translations.
llvm::ManagedStatic<llvm::StringMap<TranslateFunction>> translations;
TranslateRegistration::TranslateRegistration(
llvm::StringRef name, const TranslateFunction &function) {
if (translations->find(name) != translations->end())
llvm::report_fatal_error("Attempting to overwrite an existing function");
assert(function && "Attempting to register an empty translate function");
(*translations)[name] = function;
}
TranslateFunction getTranslation(llvm::StringRef name) {
auto it = translations->find(name);
if (it == translations->end())
return nullptr;
return it->second;
}
extern void initializeMLIRContext(MLIRContext *ctx);
Module *mlir::parseMLIRInput(StringRef inputFilename, MLIRContext *context) {
@ -111,31 +88,52 @@ static TranslateRegistration MLIRToMLIRTranslate(
return printMLIROutput(*module, outputFilename);
});
// Returns a comma-separated sorted list of the registered translations.
static std::string registeredTranslationNames() {
std::vector<StringRef> keys(translations->keys().begin(),
translations->keys().end());
llvm::sort(keys);
return llvm::join(keys, ", ");
// Static map between translations registered and the TranslateFunctions that
// perform those translations.
static llvm::ManagedStatic<llvm::StringMap<TranslateFunction>>
translationRegistry;
TranslateRegistration::TranslateRegistration(
llvm::StringRef name, const TranslateFunction &function) {
if (translationRegistry->find(name) != translationRegistry->end())
llvm::report_fatal_error("Attempting to overwrite an existing function");
assert(function && "Attempting to register an empty translate function");
(*translationRegistry)[name] = function;
}
// Custom parser for TranslateFunction.
struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
TranslationParser(llvm::cl::Option &opt)
: llvm::cl::parser<const TranslateFunction *>(opt) {
for (const auto &kv : *translationRegistry) {
addLiteralOption(kv.first(), &kv.second, kv.first());
}
}
void printOptionInfo(const llvm::cl::Option &O,
size_t GlobalWidth) const override {
TranslationParser *TP = const_cast<TranslationParser *>(this);
llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
[](const TranslationParser::OptionInfo *VT1,
const TranslationParser::OptionInfo *VT2) {
return VT1->Name.compare(VT2->Name);
});
using llvm::cl::parser;
parser<const TranslateFunction *>::printOptionInfo(O, GlobalWidth);
}
};
int main(int argc, char **argv) {
llvm::PrettyStackTraceProgram x(argc, argv);
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(
argc, argv,
"MLIR translation driver\n\nRegistered translations:\n\t" +
registeredTranslationNames() + "\n");
auto translate = getTranslation(translationRequested);
if (!translate) {
llvm::errs() << "Translation requested '" << translationRequested
<< "' not registered\n";
return 1;
}
// Add flags for all the registered translations.
llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
translationRequested("", llvm::cl::desc("Translation to perform"),
llvm::cl::Required);
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
MLIRContext context;
initializeMLIRContext(&context);
return translate(inputFilename, outputFilename, &context);
return (*translationRequested)(inputFilename, outputFilename, &context);
}

View File

@ -32,14 +32,14 @@ using TranslateFunction =
std::function<bool(llvm::StringRef inputFilename,
llvm::StringRef oututFilename, MLIRContext *)>;
// Use TranslateRegistration as a global initialiser that registers a
// function and associates it with name. This requires that a command
// has not been registered to a given name.
// Use TranslateRegistration as a global initialiser that registers a function
// and associates it with name. This requires that a translation has not been
// registered to a given name.
//
// Usage:
//
// // At namespace scope.
// static CommandRegistration Unused(&MySubCommand, [] { ... });
// static TranslateRegistration Unused(&MySubCommand, [] { ... });
//
struct TranslateRegistration {
TranslateRegistration(llvm::StringRef name,