Use Operator class in OpDefinitionsGen. Cleanup NFC.

PiperOrigin-RevId: 227764826
This commit is contained in:
Jacques Pienaar 2019-01-03 15:53:54 -08:00 committed by jpienaar
parent 0ebc0ba72e
commit dde5bf234d
3 changed files with 101 additions and 167 deletions

View File

@ -63,7 +63,9 @@ public:
attribute_iterator attribute_begin();
attribute_iterator attribute_end();
llvm::iterator_range<attribute_iterator> getAttributes();
int getNumAttributes() { return attributes.size(); }
Attribute &getAttribute(int index) { return attributes[index]; }
const Attribute &getAttribute(int index) const { return attributes[index]; }
// Operations operand accessors.
struct Operand {
@ -76,6 +78,8 @@ public:
operand_iterator operand_end();
llvm::iterator_range<operand_iterator> getOperands();
Operand &getOperand(int index) { return operands[index]; }
const Operand &getOperand(int index) const { return operands[index]; }
int getNumOperands() { return operands.size(); }
// Operations argument accessors.
using Argument = llvm::PointerUnion<Attribute *, Operand *>;

View File

@ -112,18 +112,30 @@ void Operator::populateOperandsAndAttributes() {
if (!givenName)
PrintFatalError(argDef->getLoc(), "attributes must be named");
bool isDerived = argDef->isSubClassOf(derivedAttrClass);
// Update start of derived attributes or ensure that non-derived and derived
// attributes are not interleaved.
if (derivedAttrStart == -1) {
if (isDerived)
derivedAttrStart = i;
} else {
if (!isDerived)
PrintFatalError(
def.getLoc(),
"derived attributes have to follow non-derived attributes");
}
if (isDerived)
PrintFatalError(def.getLoc(),
"derived attributes not allowed in argument list");
attributes.push_back({givenName, argDef, isDerived});
}
// Derived attributes.
derivedAttrStart = i;
for (const auto &val : def.getValues()) {
if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
if (!record->isSubClassOf(attrClass))
continue;
if (!record->isSubClassOf(derivedAttrClass))
PrintFatalError(def.getLoc(),
"unexpected Attr where only DerivedAttr is allowed");
if (record->getClasses().size() != 1) {
PrintFatalError(
def.getLoc(),
"unsupported attribute modelling, only single class expected");
}
attributes.push_back({cast<llvm::StringInit>(val.getNameInit()),
cast<DefInit>(val.getValue())->getDef(),
/*isDerived=*/true});
}
}
}

View File

@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
@ -29,6 +30,7 @@
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
static const char *const generatedArgName = "_arg";
@ -64,6 +66,18 @@ static inline StringRef getAsStringOrDefault(const Record &record,
: defaultVal;
}
static std::string getAttributeName(const Operator::Attribute &attr) {
return attr.name->getAsUnquotedString();
}
static std::string getArgumentName(const Operator &op, int index) {
const auto &operand = op.getOperand(index);
if (operand.name)
return operand.name->getAsUnquotedString();
else
return formatv("{0}_{1}", generatedArgName, index);
}
namespace {
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
@ -114,122 +128,41 @@ public:
private:
OpEmitter(const Record &def, raw_ostream &os);
// Populates the operands and attributes.
void getOperandsAndAttributes();
// Returns the class name of the op.
StringRef cppClassName() const;
// Invokes the given function over all the namespaces of the class.
void mapOverClassNamespaces(function_ref<void(StringRef)> fn) const;
// Returns the operation name.
StringRef getOperationName() const;
void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
// The record corresponding to the op.
const Record &def;
const RecordKeeper &recordKeeper;
// Record of Attr class.
Record *attrClass;
// Type of DerivedAttr.
const RecordRecTy *derivedAttrType;
// The name of the op split around '_'.
SmallVector<StringRef, 2> splittedDefName;
// The operands of the op.
SmallVector<std::pair<std::string, const DefInit *>, 4> operands;
// The attributes of the op.
SmallVector<std::pair<std::string, const DefInit *>, 4> attrs;
SmallVector<std::pair<const RecordVal *, const Record *>, 4> derivedAttrs;
// The operator being emitted.
Operator op;
raw_ostream &os;
};
} // end anonymous namespace
OpEmitter::OpEmitter(const Record &def, raw_ostream &os)
: def(def), recordKeeper(def.getRecords()),
attrClass(recordKeeper.getClass("Attr")),
derivedAttrType(recordKeeper.getClass("DerivedAttr")->getType()), os(os) {
SplitString(def.getName(), splittedDefName, "_");
getOperandsAndAttributes();
}
: def(def), op(def), os(os) {}
StringRef OpEmitter::cppClassName() const { return splittedDefName.back(); }
StringRef OpEmitter::getOperationName() const {
return def.getValueAsString("opName");
}
void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) const {
void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) {
auto &splittedDefName = op.getSplitDefName();
for (auto it = splittedDefName.begin(), e = std::prev(splittedDefName.end());
it != e; ++it)
fn(*it);
}
void OpEmitter::getOperandsAndAttributes() {
DagInit *argumentValues = def.getValueAsDag("arguments");
for (unsigned i = 0, e = argumentValues->getNumArgs(); i != e; ++i) {
auto arg = argumentValues->getArg(i);
auto givenName = argumentValues->getArgName(i);
DefInit *argDef = dyn_cast<DefInit>(arg);
if (!argDef)
PrintFatalError(def.getLoc(),
"unexpected type for " + Twine(i) + "th argument");
// Handle attribute.
if (argDef->getDef()->isSubClassOf(attrClass)) {
if (!givenName)
PrintFatalError(argDef->getDef()->getLoc(), "attributes must be named");
attrs.emplace_back(givenName->getValue(), argDef);
continue;
}
// Handle operands.
std::string name;
if (givenName)
name = givenName->getValue();
else
name = formatv("{0}_{1}", generatedArgName, i);
operands.emplace_back(name, argDef);
}
// Derived attributes.
for (const auto &val : def.getValues()) {
if (auto *record = dyn_cast<RecordRecTy>(val.getType())) {
if (record->typeIsA(derivedAttrType)) {
if (record->getClasses().size() != 1) {
PrintFatalError(
def.getLoc(),
"unsupported attribute modelling, only single class expected");
}
derivedAttrs.emplace_back(&val, *record->getClasses().begin());
continue;
}
if (record->isSubClassOf(attrClass))
PrintFatalError(def.getLoc(),
"unexpected Attr where only DerivedAttr is allowed");
}
}
}
void OpEmitter::emit(const Record &def, raw_ostream &os) {
OpEmitter emitter(def, os);
emitter.mapOverClassNamespaces(
[&os](StringRef ns) { os << "\nnamespace " << ns << "{\n"; });
os << "class " << emitter.cppClassName() << " : public Op<"
<< emitter.cppClassName();
os << formatv("class {0} : public Op<{0}", emitter.op.cppClassName());
emitter.emitTraits();
os << "> {\npublic:\n";
// Build operation name.
os << " static StringRef getOperationName() { return \""
<< emitter.getOperationName() << "\"; };\n";
<< emitter.op.getOperationName() << "\"; };\n";
emitter.emitNamedOperands();
emitter.emitBuilder();
@ -240,49 +173,34 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
emitter.emitCanonicalizationPatterns();
os << "private:\n friend class ::mlir::OperationInst;\n";
os << " explicit " << emitter.cppClassName()
<< "(const OperationInst* state) : Op(state) {}\n";
os << "};\n";
os << " explicit " << emitter.op.cppClassName()
<< "(const OperationInst* state) : Op(state) {}\n};\n";
emitter.mapOverClassNamespaces(
[&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
}
void OpEmitter::emitAttrGetters() {
for (const auto &pair : derivedAttrs) {
auto &val = *pair.first;
for (auto &attr : op.getAttributes()) {
auto name = getAttributeName(attr);
auto *def = attr.record;
// Emit the derived attribute body.
if (auto defInit = dyn_cast<DefInit>(val.getValue())) {
if (defInit->getType()->typeIsA(derivedAttrType)) {
auto *def = defInit->getDef();
os << " " << def->getValueAsString("returnType").trim() << ' '
<< val.getName() << "() const {" << def->getValueAsString("body")
<< " }\n";
continue;
}
}
}
for (const auto &pair : attrs) {
auto &name = pair.first;
auto &attr = *pair.second->getDef();
// Emit normal emitter.
if (!hasStringAttribute(attr, "storageType")) {
// Handle the base case where there is no storage type specified.
os << " Attribute " << name << "() const {\n return getAttr(\""
<< name << "\");\n }\n";
if (attr.isDerived) {
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
<< "() const {" << def->getValueAsString("body") << " }\n";
continue;
}
os << " " << attr.getValueAsString("returnType").trim() << ' ' << name
// Emit normal emitter.
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
<< "() const {\n";
// Return the queried attribute with the correct return type.
std::string attrVal =
formatv("this->getAttrOfType<{0}>(\"{1}\")",
attr.getValueAsString("storageType").trim(), name);
def->getValueAsString("storageType").trim(), name);
os << " return "
<< formatv(attr.getValueAsString("convertFromStorage"), attrVal)
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
<< ";\n }\n";
}
}
@ -295,10 +213,10 @@ void OpEmitter::emitNamedOperands() {
return this->getInstruction()->getOperand({1});
}
)";
for (int i = 0, e = operands.size(); i != e; ++i) {
const auto &op = operands[i];
if (!StringRef(op.first).startswith(generatedArgName))
os << formatv(operandMethods, op.first, i);
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name)
os << formatv(operandMethods, operand.name->getAsUnquotedString(), i);
}
}
@ -328,15 +246,17 @@ void OpEmitter::emitBuilder() {
os << ", Type returnType" << i;
// Emit parameters for all operands
for (const auto &pair : operands)
os << ", Value* " << pair.first;
for (int i = 0, e = op.getNumOperands(); i != e; ++i)
os << ", Value* " << getArgumentName(op, i);
// Emit parameters for all attributes
// TODO(antiagainst): Support default initializer for attributes
for (const auto &pair : attrs) {
const Record &attr = *pair.second->getDef();
os << ", " << getAsStringOrDefault(attr, "storageType", "Attribute").trim()
<< ' ' << pair.first;
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
break;
const Record &def = *attr.record;
os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim()
<< ' ' << getAttributeName(attr);
}
os << ") {\n";
@ -350,19 +270,18 @@ void OpEmitter::emitBuilder() {
}
// Push all operands to the result
if (!operands.empty()) {
os << " result->addOperands({" << operands.front().first;
for (auto it = operands.begin() + 1, e = operands.end(); it != e; ++it)
os << ", " << it->first;
if (op.getNumOperands() > 0) {
os << " result->addOperands({" << getArgumentName(op, 0);
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
os << ", " << getArgumentName(op, i);
os << "});\n";
}
// Push all attributes to the result
for (const auto &pair : attrs) {
StringRef name = pair.first;
os << " result->addAttribute(\"" << name << "\", " << name << ");\n";
}
for (const auto &attr : op.getAttributes())
if (!attr.isDerived)
os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
getAttributeName(attr));
os << " }\n";
// 2. Aggregated parameters
@ -378,16 +297,16 @@ void OpEmitter::emitBuilder() {
<< " result->addTypes(resultTypes);\n";
// Operands
os << " assert(args.size() == " << operands.size()
os << " assert(args.size() == " << op.getNumOperands()
<< "u && \"mismatched number of parameters\");\n"
<< " result->addOperands(args);\n\n";
// Attributes
if (attrs.empty()) {
if (op.getNumAttributes() > 0) {
os << " assert(!attributes.size() && \"no attributes expected\");\n"
<< " }\n";
} else {
os << " assert(attributes.size() >= " << attrs.size()
os << " assert(attributes.size() >= " << op.getNumAttributes()
<< "u && \"not enough attributes\");\n"
<< " for (const auto& pair : attributes)\n"
<< " result->addAttribute(pair.first, pair.second);\n"
@ -424,14 +343,17 @@ void OpEmitter::emitVerifier() {
auto valueInit = def.getValueInit("verifier");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
if (!hasCustomVerify && attrs.empty())
if (!hasCustomVerify && op.getNumAttributes() == 0)
return;
os << " bool verify() const {\n";
// Verify the attributes have the correct type.
for (const auto attr : attrs) {
auto name = attr.first;
if (!hasStringAttribute(*attr.second->getDef(), "storageType")) {
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
continue;
auto name = getAttributeName(attr);
if (!hasStringAttribute(*attr.record, "storageType")) {
os << " if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";
@ -439,10 +361,10 @@ void OpEmitter::emitVerifier() {
}
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.second->getDef()->getValueAsString("storageType").trim()
<< attr.record->getValueAsString("storageType").trim()
<< ">()) return emitOpError(\"requires "
<< attr.second->getDef()->getValueAsString("returnType").trim()
<< " attribute '" << name << "'\");\n";
<< attr.record->getValueAsString("returnType").trim() << " attribute '"
<< name << "'\");\n";
}
if (hasCustomVerify)
@ -486,13 +408,13 @@ void OpEmitter::emitTraits() {
}
}
if ((hasVariadicOperands || hasAtLeastNOperands) && !operands.empty()) {
if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) {
PrintFatalError(def.getLoc(),
"Operands number definition is not consistent.");
}
// Add operand size trait if defined explicitly.
switch (operands.size()) {
switch (op.getNumOperands()) {
case 0:
if (!hasVariadicOperands && !hasAtLeastNOperands)
os << ", OpTrait::ZeroOperands";
@ -501,7 +423,7 @@ void OpEmitter::emitTraits() {
os << ", OpTrait::OneOperand";
break;
default:
os << ", OpTrait::NOperands<" << operands.size() << ">::Impl";
os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl";
break;
}
@ -517,8 +439,7 @@ void OpEmitter::emitTraits() {
}
// Emits the opcode enum and op classes.
static void emitOpClasses(const RecordKeeper &recordKeeper,
const std::vector<Record *> &defs, raw_ostream &os) {
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_CLASSES", os);
for (auto *def : defs)
OpEmitter::emit(*def, os);
@ -532,10 +453,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
for (auto &def : defs) {
if (!first)
os << ",";
SmallVector<StringRef, 2> splittedDefName;
SplitString(def->getName(), splittedDefName, "_");
os << join(splittedDefName, "::");
os << Operator(def).qualifiedCppClassName();
first = false;
}
}
@ -546,7 +464,7 @@ static void emitOpDefinitions(const RecordKeeper &recordKeeper,
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
emitOpList(defs, os);
emitOpClasses(recordKeeper, defs, os);
emitOpClasses(defs, os);
}
static void emitOpDefFile(const RecordKeeper &recordKeeper, raw_ostream &os) {