From 81a066e6e74188e018369de650fa7113784c2a9e Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar@google.com>
Date: Tue, 18 Sep 2018 12:18:24 -0700
Subject: [PATCH] Switch from positional argument to explicit flags for
 mlir-translate

This results in uniform behavior with mlir-opt. Exactly one transformation is allowed.

PiperOrigin-RevId: 213493415
---
 mlir/test/IR/parser.mlir                     |  2 +-
 mlir/tools/mlir-translate/mlir-translate.cpp | 80 ++++++++++----------
 mlir/tools/mlir-translate/mlir-translate.h   |  8 +-
 3 files changed, 44 insertions(+), 46 deletions(-)

diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index f0fe88007185..5a19a32420f6 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate mlir-to-mlir %s -o - | FileCheck %s
+// RUN: mlir-translate -mlir-to-mlir %s -o - | FileCheck %s
 
 // CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4)
 #map0 = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d3, d4)
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 2b545cac8ee7..37019fa05cbc 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -35,10 +35,6 @@
 
 using namespace mlir;
 
-static llvm::cl::opt<std::string>
-    translationRequested(llvm::cl::Positional,
-                         llvm::cl::desc("<translation-requested>"));
-
 static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                 llvm::cl::desc("<input file>"),
                                                 llvm::cl::init("-"));
@@ -47,25 +43,6 @@ static llvm::cl::opt<std::string>
     outputFilename("o", llvm::cl::desc("Output filename"),
                    llvm::cl::value_desc("filename"), llvm::cl::init("-"));
 
-// Static map between translations registered and the TranslateFunctions that
-// perform those translations.
-llvm::ManagedStatic<llvm::StringMap<TranslateFunction>> translations;
-
-TranslateRegistration::TranslateRegistration(
-    llvm::StringRef name, const TranslateFunction &function) {
-  if (translations->find(name) != translations->end())
-    llvm::report_fatal_error("Attempting to overwrite an existing function");
-  assert(function && "Attempting to register an empty translate function");
-  (*translations)[name] = function;
-}
-
-TranslateFunction getTranslation(llvm::StringRef name) {
-  auto it = translations->find(name);
-  if (it == translations->end())
-    return nullptr;
-  return it->second;
-}
-
 extern void initializeMLIRContext(MLIRContext *ctx);
 
 Module *mlir::parseMLIRInput(StringRef inputFilename, MLIRContext *context) {
@@ -111,31 +88,52 @@ static TranslateRegistration MLIRToMLIRTranslate(
       return printMLIROutput(*module, outputFilename);
     });
 
-// Returns a comma-separated sorted list of the registered translations.
-static std::string registeredTranslationNames() {
-  std::vector<StringRef> keys(translations->keys().begin(),
-                              translations->keys().end());
-  llvm::sort(keys);
-  return llvm::join(keys, ", ");
+// Static map between translations registered and the TranslateFunctions that
+// perform those translations.
+static llvm::ManagedStatic<llvm::StringMap<TranslateFunction>>
+    translationRegistry;
+
+TranslateRegistration::TranslateRegistration(
+    llvm::StringRef name, const TranslateFunction &function) {
+  if (translationRegistry->find(name) != translationRegistry->end())
+    llvm::report_fatal_error("Attempting to overwrite an existing function");
+  assert(function && "Attempting to register an empty translate function");
+  (*translationRegistry)[name] = function;
 }
 
+// Custom parser for TranslateFunction.
+struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
+  TranslationParser(llvm::cl::Option &opt)
+      : llvm::cl::parser<const TranslateFunction *>(opt) {
+    for (const auto &kv : *translationRegistry) {
+      addLiteralOption(kv.first(), &kv.second, kv.first());
+    }
+  }
+
+  void printOptionInfo(const llvm::cl::Option &O,
+                       size_t GlobalWidth) const override {
+    TranslationParser *TP = const_cast<TranslationParser *>(this);
+    llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
+                         [](const TranslationParser::OptionInfo *VT1,
+                            const TranslationParser::OptionInfo *VT2) {
+                           return VT1->Name.compare(VT2->Name);
+                         });
+    using llvm::cl::parser;
+    parser<const TranslateFunction *>::printOptionInfo(O, GlobalWidth);
+  }
+};
+
 int main(int argc, char **argv) {
   llvm::PrettyStackTraceProgram x(argc, argv);
   llvm::InitLLVM y(argc, argv);
 
-  llvm::cl::ParseCommandLineOptions(
-      argc, argv,
-      "MLIR translation driver\n\nRegistered translations:\n\t" +
-          registeredTranslationNames() + "\n");
-
-  auto translate = getTranslation(translationRequested);
-  if (!translate) {
-    llvm::errs() << "Translation requested '" << translationRequested
-                 << "' not registered\n";
-    return 1;
-  }
+  // Add flags for all the registered translations.
+  llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
+      translationRequested("", llvm::cl::desc("Translation to perform"),
+                           llvm::cl::Required);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
 
   MLIRContext context;
   initializeMLIRContext(&context);
-  return translate(inputFilename, outputFilename, &context);
+  return (*translationRequested)(inputFilename, outputFilename, &context);
 }
diff --git a/mlir/tools/mlir-translate/mlir-translate.h b/mlir/tools/mlir-translate/mlir-translate.h
index df6694801f34..5979bea9bbbf 100644
--- a/mlir/tools/mlir-translate/mlir-translate.h
+++ b/mlir/tools/mlir-translate/mlir-translate.h
@@ -32,14 +32,14 @@ using TranslateFunction =
     std::function<bool(llvm::StringRef inputFilename,
                        llvm::StringRef oututFilename, MLIRContext *)>;
 
-// Use TranslateRegistration as a global initialiser that registers a
-// function and associates it with name. This requires that a command
-// has not been registered to a given name.
+// Use TranslateRegistration as a global initialiser that registers a function
+// and associates it with name. This requires that a translation has not been
+// registered to a given name.
 //
 // Usage:
 //
 //   // At namespace scope.
-//   static CommandRegistration Unused(&MySubCommand, [] { ... });
+//   static TranslateRegistration Unused(&MySubCommand, [] { ... });
 //
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name,