Various tiny refinements over TableGen Operator class

Use "native" vs "derived" to differentiate attributes on ops: native ones
are specified when creating the op as a part of defining the op, while
derived ones are computed from properties of the op.

PiperOrigin-RevId: 228186962
This commit is contained in:
Lei Zhang 2019-01-07 10:09:34 -08:00 committed by jpienaar
parent 65fc8643ec
commit f8bbe5deca
2 changed files with 35 additions and 25 deletions

View File

@ -15,7 +15,7 @@
// limitations under the License.
// =============================================================================
//
// Operator wrapper to simplifying using Record corresponding to Operator.
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
//===----------------------------------------------------------------------===//
@ -36,6 +36,9 @@ class StringInit;
namespace mlir {
// Wrapper class that contains a MLIR op's information (e.g., operands,
// atributes) defined in TableGen and provides helper methods for
// accessing them.
class Operator {
public:
explicit Operator(const llvm::Record &def);
@ -44,16 +47,15 @@ public:
// Returns the operation name.
StringRef getOperationName() const;
// Returns the def name split around '_'.
const SmallVectorImpl<StringRef> &getSplitDefName();
// Returns the TableGen definition name split around '_'.
const SmallVectorImpl<StringRef> &getSplitDefName() const;
// Returns the class name of the op.
StringRef cppClassName();
// Returns the C++ class name of the op.
StringRef cppClassName() const;
// Returns the class name of the op with namespace added.
std::string qualifiedCppClassName();
// Returns the C++ class name of the op with namespace added.
std::string qualifiedCppClassName() const;
// Operations attribute accessors.
struct Attribute {
std::string getName() const;
StringRef getReturnType() const;
@ -64,15 +66,17 @@ public:
bool isDerived;
};
// Op attribute interators.
using attribute_iterator = Attribute *;
attribute_iterator attribute_begin();
attribute_iterator attribute_end();
llvm::iterator_range<attribute_iterator> getAttributes();
int getNumAttributes() { return attributes.size(); }
// Op attribute accessors.
int getNumAttributes() const { return attributes.size(); }
Attribute &getAttribute(int index) { return attributes[index]; }
const Attribute &getAttribute(int index) const { return attributes[index]; }
// Operations operand accessors.
struct Operand {
bool hasMatcher() const;
// Return the matcher template for the operand type.
@ -82,15 +86,18 @@ public:
llvm::DefInit *defInit;
};
// Op operand iterators.
using operand_iterator = Operand *;
operand_iterator operand_begin();
operand_iterator operand_end();
llvm::iterator_range<operand_iterator> getOperands();
// Op operand accessors.
int getNumOperands() const { return operands.size(); }
Operand &getOperand(int index) { return operands[index]; }
const Operand &getOperand(int index) const { return operands[index]; }
int getNumOperands() { return operands.size(); }
// Operations argument accessors.
// Op argument (attribute or operand) accessors.
using Argument = llvm::PointerUnion<Attribute *, Operand *>;
Argument getArg(int index);
StringRef getArgName(int index) const;
@ -109,12 +116,15 @@ private:
// The attributes of the op.
SmallVector<Attribute, 4> attributes;
// The start of attributes.
int attrStart;
// The start of native attributes, which are specified when creating the op
// as a part of the op's definition.
int nativeAttrStart;
// The start of the derived attributes.
// The start of derived attributes, which are computed from properties of
// the op.
int derivedAttrStart;
// The TableGen definition of this op.
const llvm::Record &def;
};

View File

@ -15,7 +15,7 @@
// limitations under the License.
// =============================================================================
//
// Operator wrapper to simplifying using Record corresponding to Operator.
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
//===----------------------------------------------------------------------===//
@ -36,7 +36,7 @@ Operator::Operator(const llvm::Record &def) : def(def) {
populateOperandsAndAttributes();
}
const SmallVectorImpl<StringRef> &Operator::getSplitDefName() {
const SmallVectorImpl<StringRef> &Operator::getSplitDefName() const {
return splittedDefName;
}
@ -44,8 +44,8 @@ StringRef Operator::getOperationName() const {
return def.getValueAsString("opName");
}
StringRef Operator::cppClassName() { return getSplitDefName().back(); }
std::string Operator::qualifiedCppClassName() {
StringRef Operator::cppClassName() const { return getSplitDefName().back(); }
std::string Operator::qualifiedCppClassName() const {
return llvm::join(getSplitDefName(), "::");
}
@ -71,9 +71,9 @@ auto Operator::getOperands() -> llvm::iterator_range<operand_iterator> {
}
auto Operator::getArg(int index) -> Argument {
if (index < attrStart)
if (index < nativeAttrStart)
return {&operands[index]};
return {&attributes[index - attrStart]};
return {&attributes[index - nativeAttrStart]};
}
void Operator::populateOperandsAndAttributes() {
@ -82,7 +82,7 @@ void Operator::populateOperandsAndAttributes() {
auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
derivedAttrStart = -1;
// The argument ordering is operands, non-derived attributes, derived
// The argument ordering is operands, native attributes, derived
// attributes.
DagInit *argumentValues = def.getValueAsDag("arguments");
unsigned i = 0;
@ -100,8 +100,8 @@ void Operator::populateOperandsAndAttributes() {
operands.push_back(Operand{givenName, argDefInit});
}
// Handle attribute.
attrStart = i;
// Handle native attributes.
nativeAttrStart = i;
for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
auto arg = argumentValues->getArg(i);
auto givenName = argumentValues->getArgName(i);
@ -119,7 +119,7 @@ void Operator::populateOperandsAndAttributes() {
attributes.push_back({givenName, argDef, isDerived});
}
// Derived attributes.
// Handle derived attributes.
derivedAttrStart = i;
for (const auto &val : def.getValues()) {
if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {