forked from OSchip/llvm-project
Use Operator class in OpDefinitionsGen. Cleanup NFC.
PiperOrigin-RevId: 227764826
This commit is contained in:
parent
0ebc0ba72e
commit
dde5bf234d
|
@ -63,7 +63,9 @@ public:
|
|||
attribute_iterator attribute_begin();
|
||||
attribute_iterator attribute_end();
|
||||
llvm::iterator_range<attribute_iterator> getAttributes();
|
||||
int getNumAttributes() { return attributes.size(); }
|
||||
Attribute &getAttribute(int index) { return attributes[index]; }
|
||||
const Attribute &getAttribute(int index) const { return attributes[index]; }
|
||||
|
||||
// Operations operand accessors.
|
||||
struct Operand {
|
||||
|
@ -76,6 +78,8 @@ public:
|
|||
operand_iterator operand_end();
|
||||
llvm::iterator_range<operand_iterator> getOperands();
|
||||
Operand &getOperand(int index) { return operands[index]; }
|
||||
const Operand &getOperand(int index) const { return operands[index]; }
|
||||
int getNumOperands() { return operands.size(); }
|
||||
|
||||
// Operations argument accessors.
|
||||
using Argument = llvm::PointerUnion<Attribute *, Operand *>;
|
||||
|
|
|
@ -112,18 +112,30 @@ void Operator::populateOperandsAndAttributes() {
|
|||
if (!givenName)
|
||||
PrintFatalError(argDef->getLoc(), "attributes must be named");
|
||||
bool isDerived = argDef->isSubClassOf(derivedAttrClass);
|
||||
|
||||
// Update start of derived attributes or ensure that non-derived and derived
|
||||
// attributes are not interleaved.
|
||||
if (derivedAttrStart == -1) {
|
||||
if (isDerived)
|
||||
derivedAttrStart = i;
|
||||
} else {
|
||||
if (!isDerived)
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
"derived attributes have to follow non-derived attributes");
|
||||
}
|
||||
if (isDerived)
|
||||
PrintFatalError(def.getLoc(),
|
||||
"derived attributes not allowed in argument list");
|
||||
attributes.push_back({givenName, argDef, isDerived});
|
||||
}
|
||||
|
||||
// Derived attributes.
|
||||
derivedAttrStart = i;
|
||||
for (const auto &val : def.getValues()) {
|
||||
if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
|
||||
if (!record->isSubClassOf(attrClass))
|
||||
continue;
|
||||
if (!record->isSubClassOf(derivedAttrClass))
|
||||
PrintFatalError(def.getLoc(),
|
||||
"unexpected Attr where only DerivedAttr is allowed");
|
||||
|
||||
if (record->getClasses().size() != 1) {
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
"unsupported attribute modelling, only single class expected");
|
||||
}
|
||||
attributes.push_back({cast<llvm::StringInit>(val.getNameInit()),
|
||||
cast<DefInit>(val.getValue())->getDef(),
|
||||
/*isDerived=*/true});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
|
@ -29,6 +30,7 @@
|
|||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
static const char *const generatedArgName = "_arg";
|
||||
|
||||
|
@ -64,6 +66,18 @@ static inline StringRef getAsStringOrDefault(const Record &record,
|
|||
: defaultVal;
|
||||
}
|
||||
|
||||
static std::string getAttributeName(const Operator::Attribute &attr) {
|
||||
return attr.name->getAsUnquotedString();
|
||||
}
|
||||
|
||||
static std::string getArgumentName(const Operator &op, int index) {
|
||||
const auto &operand = op.getOperand(index);
|
||||
if (operand.name)
|
||||
return operand.name->getAsUnquotedString();
|
||||
else
|
||||
return formatv("{0}_{1}", generatedArgName, index);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Simple RAII helper for defining ifdef-undef-endif scopes.
|
||||
class IfDefScope {
|
||||
|
@ -114,122 +128,41 @@ public:
|
|||
private:
|
||||
OpEmitter(const Record &def, raw_ostream &os);
|
||||
|
||||
// Populates the operands and attributes.
|
||||
void getOperandsAndAttributes();
|
||||
|
||||
// Returns the class name of the op.
|
||||
StringRef cppClassName() const;
|
||||
|
||||
// Invokes the given function over all the namespaces of the class.
|
||||
void mapOverClassNamespaces(function_ref<void(StringRef)> fn) const;
|
||||
|
||||
// Returns the operation name.
|
||||
StringRef getOperationName() const;
|
||||
void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
|
||||
|
||||
// The record corresponding to the op.
|
||||
const Record &def;
|
||||
|
||||
const RecordKeeper &recordKeeper;
|
||||
|
||||
// Record of Attr class.
|
||||
Record *attrClass;
|
||||
|
||||
// Type of DerivedAttr.
|
||||
const RecordRecTy *derivedAttrType;
|
||||
|
||||
// The name of the op split around '_'.
|
||||
SmallVector<StringRef, 2> splittedDefName;
|
||||
|
||||
// The operands of the op.
|
||||
SmallVector<std::pair<std::string, const DefInit *>, 4> operands;
|
||||
|
||||
// The attributes of the op.
|
||||
SmallVector<std::pair<std::string, const DefInit *>, 4> attrs;
|
||||
SmallVector<std::pair<const RecordVal *, const Record *>, 4> derivedAttrs;
|
||||
// The operator being emitted.
|
||||
Operator op;
|
||||
|
||||
raw_ostream &os;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
OpEmitter::OpEmitter(const Record &def, raw_ostream &os)
|
||||
: def(def), recordKeeper(def.getRecords()),
|
||||
attrClass(recordKeeper.getClass("Attr")),
|
||||
derivedAttrType(recordKeeper.getClass("DerivedAttr")->getType()), os(os) {
|
||||
SplitString(def.getName(), splittedDefName, "_");
|
||||
getOperandsAndAttributes();
|
||||
}
|
||||
: def(def), op(def), os(os) {}
|
||||
|
||||
StringRef OpEmitter::cppClassName() const { return splittedDefName.back(); }
|
||||
|
||||
StringRef OpEmitter::getOperationName() const {
|
||||
return def.getValueAsString("opName");
|
||||
}
|
||||
|
||||
void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) const {
|
||||
void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) {
|
||||
auto &splittedDefName = op.getSplitDefName();
|
||||
for (auto it = splittedDefName.begin(), e = std::prev(splittedDefName.end());
|
||||
it != e; ++it)
|
||||
fn(*it);
|
||||
}
|
||||
|
||||
void OpEmitter::getOperandsAndAttributes() {
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
for (unsigned i = 0, e = argumentValues->getNumArgs(); i != e; ++i) {
|
||||
auto arg = argumentValues->getArg(i);
|
||||
auto givenName = argumentValues->getArgName(i);
|
||||
DefInit *argDef = dyn_cast<DefInit>(arg);
|
||||
if (!argDef)
|
||||
PrintFatalError(def.getLoc(),
|
||||
"unexpected type for " + Twine(i) + "th argument");
|
||||
|
||||
// Handle attribute.
|
||||
if (argDef->getDef()->isSubClassOf(attrClass)) {
|
||||
if (!givenName)
|
||||
PrintFatalError(argDef->getDef()->getLoc(), "attributes must be named");
|
||||
attrs.emplace_back(givenName->getValue(), argDef);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle operands.
|
||||
std::string name;
|
||||
if (givenName)
|
||||
name = givenName->getValue();
|
||||
else
|
||||
name = formatv("{0}_{1}", generatedArgName, i);
|
||||
operands.emplace_back(name, argDef);
|
||||
}
|
||||
|
||||
// Derived attributes.
|
||||
for (const auto &val : def.getValues()) {
|
||||
if (auto *record = dyn_cast<RecordRecTy>(val.getType())) {
|
||||
if (record->typeIsA(derivedAttrType)) {
|
||||
if (record->getClasses().size() != 1) {
|
||||
PrintFatalError(
|
||||
def.getLoc(),
|
||||
"unsupported attribute modelling, only single class expected");
|
||||
}
|
||||
derivedAttrs.emplace_back(&val, *record->getClasses().begin());
|
||||
continue;
|
||||
}
|
||||
if (record->isSubClassOf(attrClass))
|
||||
PrintFatalError(def.getLoc(),
|
||||
"unexpected Attr where only DerivedAttr is allowed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
||||
OpEmitter emitter(def, os);
|
||||
|
||||
emitter.mapOverClassNamespaces(
|
||||
[&os](StringRef ns) { os << "\nnamespace " << ns << "{\n"; });
|
||||
os << "class " << emitter.cppClassName() << " : public Op<"
|
||||
<< emitter.cppClassName();
|
||||
os << formatv("class {0} : public Op<{0}", emitter.op.cppClassName());
|
||||
emitter.emitTraits();
|
||||
os << "> {\npublic:\n";
|
||||
|
||||
// Build operation name.
|
||||
os << " static StringRef getOperationName() { return \""
|
||||
<< emitter.getOperationName() << "\"; };\n";
|
||||
<< emitter.op.getOperationName() << "\"; };\n";
|
||||
|
||||
emitter.emitNamedOperands();
|
||||
emitter.emitBuilder();
|
||||
|
@ -240,49 +173,34 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
|||
emitter.emitCanonicalizationPatterns();
|
||||
|
||||
os << "private:\n friend class ::mlir::OperationInst;\n";
|
||||
os << " explicit " << emitter.cppClassName()
|
||||
<< "(const OperationInst* state) : Op(state) {}\n";
|
||||
os << "};\n";
|
||||
os << " explicit " << emitter.op.cppClassName()
|
||||
<< "(const OperationInst* state) : Op(state) {}\n};\n";
|
||||
emitter.mapOverClassNamespaces(
|
||||
[&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
|
||||
}
|
||||
|
||||
void OpEmitter::emitAttrGetters() {
|
||||
for (const auto &pair : derivedAttrs) {
|
||||
auto &val = *pair.first;
|
||||
for (auto &attr : op.getAttributes()) {
|
||||
auto name = getAttributeName(attr);
|
||||
auto *def = attr.record;
|
||||
|
||||
// Emit the derived attribute body.
|
||||
if (auto defInit = dyn_cast<DefInit>(val.getValue())) {
|
||||
if (defInit->getType()->typeIsA(derivedAttrType)) {
|
||||
auto *def = defInit->getDef();
|
||||
os << " " << def->getValueAsString("returnType").trim() << ' '
|
||||
<< val.getName() << "() const {" << def->getValueAsString("body")
|
||||
<< " }\n";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &pair : attrs) {
|
||||
auto &name = pair.first;
|
||||
auto &attr = *pair.second->getDef();
|
||||
// Emit normal emitter.
|
||||
if (!hasStringAttribute(attr, "storageType")) {
|
||||
// Handle the base case where there is no storage type specified.
|
||||
os << " Attribute " << name << "() const {\n return getAttr(\""
|
||||
<< name << "\");\n }\n";
|
||||
if (attr.isDerived) {
|
||||
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
|
||||
<< "() const {" << def->getValueAsString("body") << " }\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
os << " " << attr.getValueAsString("returnType").trim() << ' ' << name
|
||||
// Emit normal emitter.
|
||||
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
|
||||
<< "() const {\n";
|
||||
|
||||
// Return the queried attribute with the correct return type.
|
||||
std::string attrVal =
|
||||
formatv("this->getAttrOfType<{0}>(\"{1}\")",
|
||||
attr.getValueAsString("storageType").trim(), name);
|
||||
def->getValueAsString("storageType").trim(), name);
|
||||
os << " return "
|
||||
<< formatv(attr.getValueAsString("convertFromStorage"), attrVal)
|
||||
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
|
||||
<< ";\n }\n";
|
||||
}
|
||||
}
|
||||
|
@ -295,10 +213,10 @@ void OpEmitter::emitNamedOperands() {
|
|||
return this->getInstruction()->getOperand({1});
|
||||
}
|
||||
)";
|
||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||
const auto &op = operands[i];
|
||||
if (!StringRef(op.first).startswith(generatedArgName))
|
||||
os << formatv(operandMethods, op.first, i);
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
if (operand.name)
|
||||
os << formatv(operandMethods, operand.name->getAsUnquotedString(), i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -328,15 +246,17 @@ void OpEmitter::emitBuilder() {
|
|||
os << ", Type returnType" << i;
|
||||
|
||||
// Emit parameters for all operands
|
||||
for (const auto &pair : operands)
|
||||
os << ", Value* " << pair.first;
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i)
|
||||
os << ", Value* " << getArgumentName(op, i);
|
||||
|
||||
// Emit parameters for all attributes
|
||||
// TODO(antiagainst): Support default initializer for attributes
|
||||
for (const auto &pair : attrs) {
|
||||
const Record &attr = *pair.second->getDef();
|
||||
os << ", " << getAsStringOrDefault(attr, "storageType", "Attribute").trim()
|
||||
<< ' ' << pair.first;
|
||||
for (const auto &attr : op.getAttributes()) {
|
||||
if (attr.isDerived)
|
||||
break;
|
||||
const Record &def = *attr.record;
|
||||
os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim()
|
||||
<< ' ' << getAttributeName(attr);
|
||||
}
|
||||
|
||||
os << ") {\n";
|
||||
|
@ -350,19 +270,18 @@ void OpEmitter::emitBuilder() {
|
|||
}
|
||||
|
||||
// Push all operands to the result
|
||||
if (!operands.empty()) {
|
||||
os << " result->addOperands({" << operands.front().first;
|
||||
for (auto it = operands.begin() + 1, e = operands.end(); it != e; ++it)
|
||||
os << ", " << it->first;
|
||||
if (op.getNumOperands() > 0) {
|
||||
os << " result->addOperands({" << getArgumentName(op, 0);
|
||||
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
|
||||
os << ", " << getArgumentName(op, i);
|
||||
os << "});\n";
|
||||
}
|
||||
|
||||
// Push all attributes to the result
|
||||
for (const auto &pair : attrs) {
|
||||
StringRef name = pair.first;
|
||||
os << " result->addAttribute(\"" << name << "\", " << name << ");\n";
|
||||
}
|
||||
|
||||
for (const auto &attr : op.getAttributes())
|
||||
if (!attr.isDerived)
|
||||
os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
|
||||
getAttributeName(attr));
|
||||
os << " }\n";
|
||||
|
||||
// 2. Aggregated parameters
|
||||
|
@ -378,16 +297,16 @@ void OpEmitter::emitBuilder() {
|
|||
<< " result->addTypes(resultTypes);\n";
|
||||
|
||||
// Operands
|
||||
os << " assert(args.size() == " << operands.size()
|
||||
os << " assert(args.size() == " << op.getNumOperands()
|
||||
<< "u && \"mismatched number of parameters\");\n"
|
||||
<< " result->addOperands(args);\n\n";
|
||||
|
||||
// Attributes
|
||||
if (attrs.empty()) {
|
||||
if (op.getNumAttributes() > 0) {
|
||||
os << " assert(!attributes.size() && \"no attributes expected\");\n"
|
||||
<< " }\n";
|
||||
} else {
|
||||
os << " assert(attributes.size() >= " << attrs.size()
|
||||
os << " assert(attributes.size() >= " << op.getNumAttributes()
|
||||
<< "u && \"not enough attributes\");\n"
|
||||
<< " for (const auto& pair : attributes)\n"
|
||||
<< " result->addAttribute(pair.first, pair.second);\n"
|
||||
|
@ -424,14 +343,17 @@ void OpEmitter::emitVerifier() {
|
|||
auto valueInit = def.getValueInit("verifier");
|
||||
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
|
||||
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
|
||||
if (!hasCustomVerify && attrs.empty())
|
||||
if (!hasCustomVerify && op.getNumAttributes() == 0)
|
||||
return;
|
||||
|
||||
os << " bool verify() const {\n";
|
||||
// Verify the attributes have the correct type.
|
||||
for (const auto attr : attrs) {
|
||||
auto name = attr.first;
|
||||
if (!hasStringAttribute(*attr.second->getDef(), "storageType")) {
|
||||
for (const auto &attr : op.getAttributes()) {
|
||||
if (attr.isDerived)
|
||||
continue;
|
||||
|
||||
auto name = getAttributeName(attr);
|
||||
if (!hasStringAttribute(*attr.record, "storageType")) {
|
||||
os << " if (!this->getAttr(\"" << name
|
||||
<< "\")) return emitOpError(\"requires attribute '" << name
|
||||
<< "'\");\n";
|
||||
|
@ -439,10 +361,10 @@ void OpEmitter::emitVerifier() {
|
|||
}
|
||||
|
||||
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||
<< attr.second->getDef()->getValueAsString("storageType").trim()
|
||||
<< attr.record->getValueAsString("storageType").trim()
|
||||
<< ">()) return emitOpError(\"requires "
|
||||
<< attr.second->getDef()->getValueAsString("returnType").trim()
|
||||
<< " attribute '" << name << "'\");\n";
|
||||
<< attr.record->getValueAsString("returnType").trim() << " attribute '"
|
||||
<< name << "'\");\n";
|
||||
}
|
||||
|
||||
if (hasCustomVerify)
|
||||
|
@ -486,13 +408,13 @@ void OpEmitter::emitTraits() {
|
|||
}
|
||||
}
|
||||
|
||||
if ((hasVariadicOperands || hasAtLeastNOperands) && !operands.empty()) {
|
||||
if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
"Operands number definition is not consistent.");
|
||||
}
|
||||
|
||||
// Add operand size trait if defined explicitly.
|
||||
switch (operands.size()) {
|
||||
switch (op.getNumOperands()) {
|
||||
case 0:
|
||||
if (!hasVariadicOperands && !hasAtLeastNOperands)
|
||||
os << ", OpTrait::ZeroOperands";
|
||||
|
@ -501,7 +423,7 @@ void OpEmitter::emitTraits() {
|
|||
os << ", OpTrait::OneOperand";
|
||||
break;
|
||||
default:
|
||||
os << ", OpTrait::NOperands<" << operands.size() << ">::Impl";
|
||||
os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl";
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -517,8 +439,7 @@ void OpEmitter::emitTraits() {
|
|||
}
|
||||
|
||||
// Emits the opcode enum and op classes.
|
||||
static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||
const std::vector<Record *> &defs, raw_ostream &os) {
|
||||
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os) {
|
||||
IfDefScope scope("GET_OP_CLASSES", os);
|
||||
for (auto *def : defs)
|
||||
OpEmitter::emit(*def, os);
|
||||
|
@ -532,10 +453,7 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
|||
for (auto &def : defs) {
|
||||
if (!first)
|
||||
os << ",";
|
||||
|
||||
SmallVector<StringRef, 2> splittedDefName;
|
||||
SplitString(def->getName(), splittedDefName, "_");
|
||||
os << join(splittedDefName, "::");
|
||||
os << Operator(def).qualifiedCppClassName();
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
|
@ -546,7 +464,7 @@ static void emitOpDefinitions(const RecordKeeper &recordKeeper,
|
|||
|
||||
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
|
||||
emitOpList(defs, os);
|
||||
emitOpClasses(recordKeeper, defs, os);
|
||||
emitOpClasses(defs, os);
|
||||
}
|
||||
|
||||
static void emitOpDefFile(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
|
|
Loading…
Reference in New Issue