[TableGen] Fix support for ops whose names have a leading underscore

TensorFlow ops have the convention to use leading underscore to denote
    internal ops.

--

PiperOrigin-RevId: 243627723
This commit is contained in:
Lei Zhang 2019-04-15 09:13:22 -07:00 committed by Mehdi Amini
parent 09c053bfd0
commit 8bb8351710
4 changed files with 54 additions and 30 deletions

View File

@ -53,17 +53,20 @@ public:
// Returns the operation name.
StringRef getOperationName() const;
// Returns dialect name of the op.
// Returns this op's dialect name.
StringRef getDialectName() const;
// Returns the C++ class name of the op.
// Returns this op's C++ class name.
StringRef getCppClassName() const;
// Returns the C++ class name of the op with namespace added.
std::string getQualCppClassName() const;
// Returns the qualified C++ class name for the given TableGen def `name`.
// The first `_` in `name` is treated as separating the dialect namespace
// and the op class name if the dialect namespace is not empty. Otherwise,
// if `name` starts with a `_`, the `_` is considered as part the class name.
static std::string getQualCppClassName(StringRef name);
// Returns the TableGen definition name split around '_'.
ArrayRef<StringRef> getSplitDefName() const;
// Returns this op's C++ class name prefixed with dialect namespace.
std::string getQualCppClassName() const;
// Returns the number of results this op produces.
int getNumResults() const;
@ -145,8 +148,11 @@ private:
// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();
// The name of the op split around '_'.
SmallVector<StringRef, 2> splittedDefName;
// The dialect name of the op.
StringRef dialectName;
// The unqualified C++ class name of the op.
StringRef cppClassName;
// The operands of the op.
SmallVector<NamedTypeConstraint, 4> operands;

View File

@ -23,7 +23,6 @@
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -35,27 +34,36 @@ using llvm::DefInit;
using llvm::Record;
tblgen::Operator::Operator(const llvm::Record &def) : def(def) {
SplitString(def.getName(), splittedDefName, "_");
populateOpStructure();
}
std::tie(dialectName, cppClassName) = def.getName().split('_');
if (dialectName.empty()) {
// Class name with a leading underscore and without dialect name
cppClassName = def.getName();
} else if (cppClassName.empty()) {
// Class name without dialect name
std::swap(dialectName, cppClassName);
}
ArrayRef<StringRef> tblgen::Operator::getSplitDefName() const {
return splittedDefName;
populateOpStructure();
}
StringRef tblgen::Operator::getOperationName() const {
return def.getValueAsString("opName");
}
StringRef tblgen::Operator::getDialectName() const {
return getSplitDefName().front();
StringRef tblgen::Operator::getDialectName() const { return dialectName; }
StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
std::string tblgen::Operator::getQualCppClassName(StringRef name) {
StringRef ns, cls;
std::tie(ns, cls) = name.split('_');
if (ns.empty() || cls.empty())
return name;
return (ns + "::" + cls).str();
}
StringRef tblgen::Operator::getCppClassName() const {
return getSplitDefName().back();
}
std::string tblgen::Operator::getQualCppClassName() const {
return llvm::join(getSplitDefName(), "::");
return getQualCppClassName(def.getName());
}
int tblgen::Operator::getNumResults() const {

View File

@ -47,3 +47,13 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> {
// CHECK: LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
// CHECK: bool fold(SmallVectorImpl<Value *> &results);
// CHECK: };
def NS__BOp : Op<"_b_op", []>;
// CHECK-LABEL: NS::_BOp declarations
// CHECK: class _BOp : public Op<_BOp
def _COp : Op<"_c_op", []>;
// CHECK-LABEL: _COp declarations
// CHECK: class _COp : public Op<_COp

View File

@ -58,13 +58,6 @@ static inline bool hasStringAttribute(const Record &record,
return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
}
// Returns the given `op`'s qualified C++ class name.
static std::string getOpQualClassName(const Record &op) {
SmallVector<StringRef, 2> splittedName;
llvm::SplitString(op.getName(), splittedName, "_");
return llvm::join(splittedName, "::");
}
static std::string getArgumentName(const Operator &op, int index) {
const auto &operand = op.getOperand(index);
if (!operand.name.empty())
@ -938,10 +931,14 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
IfDefScope scope("GET_OP_CLASSES", os);
for (auto *def : defs) {
if (emitDecl) {
os << formatv(opCommentHeader, getOpQualClassName(*def), "declarations");
os << formatv(opCommentHeader,
Operator::getQualCppClassName(def->getName()),
"declarations");
OpEmitter::emitDecl(*def, os);
} else {
os << formatv(opCommentHeader, getOpQualClassName(*def), "definitions");
os << formatv(opCommentHeader,
Operator::getQualCppClassName(def->getName()),
"definitions");
OpEmitter::emitDef(*def, os);
}
}
@ -952,7 +949,10 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os);
interleave(
defs, [&os](Record *def) { os << getOpQualClassName(*def); },
defs,
[&os](Record *def) {
os << Operator::getQualCppClassName(def->getName());
},
[&os]() { os << ",\n"; });
}