From 81a066e6e74188e018369de650fa7113784c2a9e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar <jpienaar@google.com> Date: Tue, 18 Sep 2018 12:18:24 -0700 Subject: [PATCH] Switch from positional argument to explicit flags for mlir-translate This results in uniform behavior with mlir-opt. Exactly one transformation is allowed. PiperOrigin-RevId: 213493415 --- mlir/test/IR/parser.mlir | 2 +- mlir/tools/mlir-translate/mlir-translate.cpp | 80 ++++++++++---------- mlir/tools/mlir-translate/mlir-translate.h | 8 +- 3 files changed, 44 insertions(+), 46 deletions(-) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index f0fe88007185..5a19a32420f6 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate mlir-to-mlir %s -o - | FileCheck %s +// RUN: mlir-translate -mlir-to-mlir %s -o - | FileCheck %s // CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4) #map0 = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4) diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp index 2b545cac8ee7..37019fa05cbc 100644 --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -35,10 +35,6 @@ using namespace mlir; -static llvm::cl::opt<std::string> - translationRequested(llvm::cl::Positional, - llvm::cl::desc("<translation-requested>")); - static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-")); @@ -47,25 +43,6 @@ static llvm::cl::opt<std::string> outputFilename("o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("-")); -// Static map between translations registered and the TranslateFunctions that -// perform those translations. -llvm::ManagedStatic<llvm::StringMap<TranslateFunction>> translations; - -TranslateRegistration::TranslateRegistration( - llvm::StringRef name, const TranslateFunction &function) { - if (translations->find(name) != translations->end()) - llvm::report_fatal_error("Attempting to overwrite an existing function"); - assert(function && "Attempting to register an empty translate function"); - (*translations)[name] = function; -} - -TranslateFunction getTranslation(llvm::StringRef name) { - auto it = translations->find(name); - if (it == translations->end()) - return nullptr; - return it->second; -} - extern void initializeMLIRContext(MLIRContext *ctx); Module *mlir::parseMLIRInput(StringRef inputFilename, MLIRContext *context) { @@ -111,31 +88,52 @@ static TranslateRegistration MLIRToMLIRTranslate( return printMLIROutput(*module, outputFilename); }); -// Returns a comma-separated sorted list of the registered translations. -static std::string registeredTranslationNames() { - std::vector<StringRef> keys(translations->keys().begin(), - translations->keys().end()); - llvm::sort(keys); - return llvm::join(keys, ", "); +// Static map between translations registered and the TranslateFunctions that +// perform those translations. +static llvm::ManagedStatic<llvm::StringMap<TranslateFunction>> + translationRegistry; + +TranslateRegistration::TranslateRegistration( + llvm::StringRef name, const TranslateFunction &function) { + if (translationRegistry->find(name) != translationRegistry->end()) + llvm::report_fatal_error("Attempting to overwrite an existing function"); + assert(function && "Attempting to register an empty translate function"); + (*translationRegistry)[name] = function; } +// Custom parser for TranslateFunction. +struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> { + TranslationParser(llvm::cl::Option &opt) + : llvm::cl::parser<const TranslateFunction *>(opt) { + for (const auto &kv : *translationRegistry) { + addLiteralOption(kv.first(), &kv.second, kv.first()); + } + } + + void printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const override { + TranslationParser *TP = const_cast<TranslationParser *>(this); + llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(), + [](const TranslationParser::OptionInfo *VT1, + const TranslationParser::OptionInfo *VT2) { + return VT1->Name.compare(VT2->Name); + }); + using llvm::cl::parser; + parser<const TranslateFunction *>::printOptionInfo(O, GlobalWidth); + } +}; + int main(int argc, char **argv) { llvm::PrettyStackTraceProgram x(argc, argv); llvm::InitLLVM y(argc, argv); - llvm::cl::ParseCommandLineOptions( - argc, argv, - "MLIR translation driver\n\nRegistered translations:\n\t" + - registeredTranslationNames() + "\n"); - - auto translate = getTranslation(translationRequested); - if (!translate) { - llvm::errs() << "Translation requested '" << translationRequested - << "' not registered\n"; - return 1; - } + // Add flags for all the registered translations. + llvm::cl::opt<const TranslateFunction *, false, TranslationParser> + translationRequested("", llvm::cl::desc("Translation to perform"), + llvm::cl::Required); + llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n"); MLIRContext context; initializeMLIRContext(&context); - return translate(inputFilename, outputFilename, &context); + return (*translationRequested)(inputFilename, outputFilename, &context); } diff --git a/mlir/tools/mlir-translate/mlir-translate.h b/mlir/tools/mlir-translate/mlir-translate.h index df6694801f34..5979bea9bbbf 100644 --- a/mlir/tools/mlir-translate/mlir-translate.h +++ b/mlir/tools/mlir-translate/mlir-translate.h @@ -32,14 +32,14 @@ using TranslateFunction = std::function<bool(llvm::StringRef inputFilename, llvm::StringRef oututFilename, MLIRContext *)>; -// Use TranslateRegistration as a global initialiser that registers a -// function and associates it with name. This requires that a command -// has not been registered to a given name. +// Use TranslateRegistration as a global initialiser that registers a function +// and associates it with name. This requires that a translation has not been +// registered to a given name. // // Usage: // // // At namespace scope. -// static CommandRegistration Unused(&MySubCommand, [] { ... }); +// static TranslateRegistration Unused(&MySubCommand, [] { ... }); // struct TranslateRegistration { TranslateRegistration(llvm::StringRef name,