forked from OSchip/llvm-project
[TableGen] Fix support for ops whose names have a leading underscore
TensorFlow ops have the convention to use leading underscore to denote internal ops. -- PiperOrigin-RevId: 243627723
This commit is contained in:
parent
09c053bfd0
commit
8bb8351710
|
@ -53,17 +53,20 @@ public:
|
|||
// Returns the operation name.
|
||||
StringRef getOperationName() const;
|
||||
|
||||
// Returns dialect name of the op.
|
||||
// Returns this op's dialect name.
|
||||
StringRef getDialectName() const;
|
||||
|
||||
// Returns the C++ class name of the op.
|
||||
// Returns this op's C++ class name.
|
||||
StringRef getCppClassName() const;
|
||||
|
||||
// Returns the C++ class name of the op with namespace added.
|
||||
std::string getQualCppClassName() const;
|
||||
// Returns the qualified C++ class name for the given TableGen def `name`.
|
||||
// The first `_` in `name` is treated as separating the dialect namespace
|
||||
// and the op class name if the dialect namespace is not empty. Otherwise,
|
||||
// if `name` starts with a `_`, the `_` is considered as part the class name.
|
||||
static std::string getQualCppClassName(StringRef name);
|
||||
|
||||
// Returns the TableGen definition name split around '_'.
|
||||
ArrayRef<StringRef> getSplitDefName() const;
|
||||
// Returns this op's C++ class name prefixed with dialect namespace.
|
||||
std::string getQualCppClassName() const;
|
||||
|
||||
// Returns the number of results this op produces.
|
||||
int getNumResults() const;
|
||||
|
@ -145,8 +148,11 @@ private:
|
|||
// Populates the vectors containing operands, attributes, results and traits.
|
||||
void populateOpStructure();
|
||||
|
||||
// The name of the op split around '_'.
|
||||
SmallVector<StringRef, 2> splittedDefName;
|
||||
// The dialect name of the op.
|
||||
StringRef dialectName;
|
||||
|
||||
// The unqualified C++ class name of the op.
|
||||
StringRef cppClassName;
|
||||
|
||||
// The operands of the op.
|
||||
SmallVector<NamedTypeConstraint, 4> operands;
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include "mlir/TableGen/OpTrait.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "mlir/TableGen/Type.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
@ -35,27 +34,36 @@ using llvm::DefInit;
|
|||
using llvm::Record;
|
||||
|
||||
tblgen::Operator::Operator(const llvm::Record &def) : def(def) {
|
||||
SplitString(def.getName(), splittedDefName, "_");
|
||||
populateOpStructure();
|
||||
}
|
||||
std::tie(dialectName, cppClassName) = def.getName().split('_');
|
||||
if (dialectName.empty()) {
|
||||
// Class name with a leading underscore and without dialect name
|
||||
cppClassName = def.getName();
|
||||
} else if (cppClassName.empty()) {
|
||||
// Class name without dialect name
|
||||
std::swap(dialectName, cppClassName);
|
||||
}
|
||||
|
||||
ArrayRef<StringRef> tblgen::Operator::getSplitDefName() const {
|
||||
return splittedDefName;
|
||||
populateOpStructure();
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getOperationName() const {
|
||||
return def.getValueAsString("opName");
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getDialectName() const {
|
||||
return getSplitDefName().front();
|
||||
StringRef tblgen::Operator::getDialectName() const { return dialectName; }
|
||||
|
||||
StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
|
||||
|
||||
std::string tblgen::Operator::getQualCppClassName(StringRef name) {
|
||||
StringRef ns, cls;
|
||||
std::tie(ns, cls) = name.split('_');
|
||||
if (ns.empty() || cls.empty())
|
||||
return name;
|
||||
return (ns + "::" + cls).str();
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getCppClassName() const {
|
||||
return getSplitDefName().back();
|
||||
}
|
||||
std::string tblgen::Operator::getQualCppClassName() const {
|
||||
return llvm::join(getSplitDefName(), "::");
|
||||
return getQualCppClassName(def.getName());
|
||||
}
|
||||
|
||||
int tblgen::Operator::getNumResults() const {
|
||||
|
|
|
@ -47,3 +47,13 @@ def NS_AOp : Op<"a_op", [NoSideEffect]> {
|
|||
// CHECK: LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
|
||||
// CHECK: bool fold(SmallVectorImpl<Value *> &results);
|
||||
// CHECK: };
|
||||
|
||||
def NS__BOp : Op<"_b_op", []>;
|
||||
|
||||
// CHECK-LABEL: NS::_BOp declarations
|
||||
// CHECK: class _BOp : public Op<_BOp
|
||||
|
||||
def _COp : Op<"_c_op", []>;
|
||||
|
||||
// CHECK-LABEL: _COp declarations
|
||||
// CHECK: class _COp : public Op<_COp
|
||||
|
|
|
@ -58,13 +58,6 @@ static inline bool hasStringAttribute(const Record &record,
|
|||
return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
|
||||
}
|
||||
|
||||
// Returns the given `op`'s qualified C++ class name.
|
||||
static std::string getOpQualClassName(const Record &op) {
|
||||
SmallVector<StringRef, 2> splittedName;
|
||||
llvm::SplitString(op.getName(), splittedName, "_");
|
||||
return llvm::join(splittedName, "::");
|
||||
}
|
||||
|
||||
static std::string getArgumentName(const Operator &op, int index) {
|
||||
const auto &operand = op.getOperand(index);
|
||||
if (!operand.name.empty())
|
||||
|
@ -938,10 +931,14 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
|
|||
IfDefScope scope("GET_OP_CLASSES", os);
|
||||
for (auto *def : defs) {
|
||||
if (emitDecl) {
|
||||
os << formatv(opCommentHeader, getOpQualClassName(*def), "declarations");
|
||||
os << formatv(opCommentHeader,
|
||||
Operator::getQualCppClassName(def->getName()),
|
||||
"declarations");
|
||||
OpEmitter::emitDecl(*def, os);
|
||||
} else {
|
||||
os << formatv(opCommentHeader, getOpQualClassName(*def), "definitions");
|
||||
os << formatv(opCommentHeader,
|
||||
Operator::getQualCppClassName(def->getName()),
|
||||
"definitions");
|
||||
OpEmitter::emitDef(*def, os);
|
||||
}
|
||||
}
|
||||
|
@ -952,7 +949,10 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
|||
IfDefScope scope("GET_OP_LIST", os);
|
||||
|
||||
interleave(
|
||||
defs, [&os](Record *def) { os << getOpQualClassName(*def); },
|
||||
defs,
|
||||
[&os](Record *def) {
|
||||
os << Operator::getQualCppClassName(def->getName());
|
||||
},
|
||||
[&os]() { os << ",\n"; });
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue