From 8bb8351710b91d1a89a774ba2c46d75f83c432e2 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 15 Apr 2019 09:13:22 -0700 Subject: [PATCH] [TableGen] Fix support for ops whose names have a leading underscore TensorFlow ops have the convention to use leading underscore to denote internal ops. -- PiperOrigin-RevId: 243627723 --- mlir/include/mlir/TableGen/Operator.h | 22 ++++++++------ mlir/lib/TableGen/Operator.cpp | 32 +++++++++++++-------- mlir/test/mlir-tblgen/op-decl.td | 10 +++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 20 ++++++------- 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 64509117db3a..723362611e8a 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -53,17 +53,20 @@ public: // Returns the operation name. StringRef getOperationName() const; - // Returns dialect name of the op. + // Returns this op's dialect name. StringRef getDialectName() const; - // Returns the C++ class name of the op. + // Returns this op's C++ class name. StringRef getCppClassName() const; - // Returns the C++ class name of the op with namespace added. - std::string getQualCppClassName() const; + // Returns the qualified C++ class name for the given TableGen def `name`. + // The first `_` in `name` is treated as separating the dialect namespace + // and the op class name if the dialect namespace is not empty. Otherwise, + // if `name` starts with a `_`, the `_` is considered as part the class name. + static std::string getQualCppClassName(StringRef name); - // Returns the TableGen definition name split around '_'. - ArrayRef getSplitDefName() const; + // Returns this op's C++ class name prefixed with dialect namespace. + std::string getQualCppClassName() const; // Returns the number of results this op produces. int getNumResults() const; @@ -145,8 +148,11 @@ private: // Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); - // The name of the op split around '_'. - SmallVector splittedDefName; + // The dialect name of the op. + StringRef dialectName; + + // The unqualified C++ class name of the op. + StringRef cppClassName; // The operands of the op. SmallVector operands; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 9c93105143d7..6f0f3ea61b72 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -23,7 +23,6 @@ #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Type.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -35,27 +34,36 @@ using llvm::DefInit; using llvm::Record; tblgen::Operator::Operator(const llvm::Record &def) : def(def) { - SplitString(def.getName(), splittedDefName, "_"); - populateOpStructure(); -} + std::tie(dialectName, cppClassName) = def.getName().split('_'); + if (dialectName.empty()) { + // Class name with a leading underscore and without dialect name + cppClassName = def.getName(); + } else if (cppClassName.empty()) { + // Class name without dialect name + std::swap(dialectName, cppClassName); + } -ArrayRef tblgen::Operator::getSplitDefName() const { - return splittedDefName; + populateOpStructure(); } StringRef tblgen::Operator::getOperationName() const { return def.getValueAsString("opName"); } -StringRef tblgen::Operator::getDialectName() const { - return getSplitDefName().front(); +StringRef tblgen::Operator::getDialectName() const { return dialectName; } + +StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } + +std::string tblgen::Operator::getQualCppClassName(StringRef name) { + StringRef ns, cls; + std::tie(ns, cls) = name.split('_'); + if (ns.empty() || cls.empty()) + return name; + return (ns + "::" + cls).str(); } -StringRef tblgen::Operator::getCppClassName() const { - return getSplitDefName().back(); -} std::string tblgen::Operator::getQualCppClassName() const { - return llvm::join(getSplitDefName(), "::"); + return getQualCppClassName(def.getName()); } int tblgen::Operator::getNumResults() const { diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 3b02e5e987d5..a6030f22b6f9 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -47,3 +47,13 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> { // CHECK: LogicalResult constantFold(ArrayRef operands, SmallVectorImpl &results, MLIRContext *context); // CHECK: bool fold(SmallVectorImpl &results); // CHECK: }; + +def NS__BOp : Op<"_b_op", []>; + +// CHECK-LABEL: NS::_BOp declarations +// CHECK: class _BOp : public Op<_BOp + +def _COp : Op<"_c_op", []>; + +// CHECK-LABEL: _COp declarations +// CHECK: class _COp : public Op<_COp diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b67a4f97bfa6..398fe785cc1f 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -58,13 +58,6 @@ static inline bool hasStringAttribute(const Record &record, return isa(valueInit) || isa(valueInit); } -// Returns the given `op`'s qualified C++ class name. -static std::string getOpQualClassName(const Record &op) { - SmallVector splittedName; - llvm::SplitString(op.getName(), splittedName, "_"); - return llvm::join(splittedName, "::"); -} - static std::string getArgumentName(const Operator &op, int index) { const auto &operand = op.getOperand(index); if (!operand.name.empty()) @@ -938,10 +931,14 @@ static void emitOpClasses(const std::vector &defs, raw_ostream &os, IfDefScope scope("GET_OP_CLASSES", os); for (auto *def : defs) { if (emitDecl) { - os << formatv(opCommentHeader, getOpQualClassName(*def), "declarations"); + os << formatv(opCommentHeader, + Operator::getQualCppClassName(def->getName()), + "declarations"); OpEmitter::emitDecl(*def, os); } else { - os << formatv(opCommentHeader, getOpQualClassName(*def), "definitions"); + os << formatv(opCommentHeader, + Operator::getQualCppClassName(def->getName()), + "definitions"); OpEmitter::emitDef(*def, os); } } @@ -952,7 +949,10 @@ static void emitOpList(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_LIST", os); interleave( - defs, [&os](Record *def) { os << getOpQualClassName(*def); }, + defs, + [&os](Record *def) { + os << Operator::getQualCppClassName(def->getName()); + }, [&os]() { os << ",\n"; }); }