From dde5bf234d9841a9b15d42b7687cc4d12ffe4652 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 3 Jan 2019 15:53:54 -0800 Subject: [PATCH] Use Operator class in OpDefinitionsGen. Cleanup NFC. PiperOrigin-RevId: 227764826 --- mlir/include/mlir/TableGen/Operator.h | 4 + mlir/lib/TableGen/Operator.cpp | 36 ++-- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 228 +++++++------------- 3 files changed, 101 insertions(+), 167 deletions(-) diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index fac9194ce6e4..95dc6b899b01 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -63,7 +63,9 @@ public: attribute_iterator attribute_begin(); attribute_iterator attribute_end(); llvm::iterator_range getAttributes(); + int getNumAttributes() { return attributes.size(); } Attribute &getAttribute(int index) { return attributes[index]; } + const Attribute &getAttribute(int index) const { return attributes[index]; } // Operations operand accessors. struct Operand { @@ -76,6 +78,8 @@ public: operand_iterator operand_end(); llvm::iterator_range getOperands(); Operand &getOperand(int index) { return operands[index]; } + const Operand &getOperand(int index) const { return operands[index]; } + int getNumOperands() { return operands.size(); } // Operations argument accessors. using Argument = llvm::PointerUnion; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index c4021d3cae0d..b702037d64e6 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -112,18 +112,30 @@ void Operator::populateOperandsAndAttributes() { if (!givenName) PrintFatalError(argDef->getLoc(), "attributes must be named"); bool isDerived = argDef->isSubClassOf(derivedAttrClass); - - // Update start of derived attributes or ensure that non-derived and derived - // attributes are not interleaved. - if (derivedAttrStart == -1) { - if (isDerived) - derivedAttrStart = i; - } else { - if (!isDerived) - PrintFatalError( - def.getLoc(), - "derived attributes have to follow non-derived attributes"); - } + if (isDerived) + PrintFatalError(def.getLoc(), + "derived attributes not allowed in argument list"); attributes.push_back({givenName, argDef, isDerived}); } + + // Derived attributes. + derivedAttrStart = i; + for (const auto &val : def.getValues()) { + if (auto *record = dyn_cast(val.getType())) { + if (!record->isSubClassOf(attrClass)) + continue; + if (!record->isSubClassOf(derivedAttrClass)) + PrintFatalError(def.getLoc(), + "unexpected Attr where only DerivedAttr is allowed"); + + if (record->getClasses().size() != 1) { + PrintFatalError( + def.getLoc(), + "unsupported attribute modelling, only single class expected"); + } + attributes.push_back({cast(val.getNameInit()), + cast(val.getValue())->getDef(), + /*isDerived=*/true}); + } + } } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index d08e1ab8dfd2..c55931546bf5 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Signals.h" @@ -29,6 +30,7 @@ #include "llvm/TableGen/TableGenBackend.h" using namespace llvm; +using namespace mlir; static const char *const generatedArgName = "_arg"; @@ -64,6 +66,18 @@ static inline StringRef getAsStringOrDefault(const Record &record, : defaultVal; } +static std::string getAttributeName(const Operator::Attribute &attr) { + return attr.name->getAsUnquotedString(); +} + +static std::string getArgumentName(const Operator &op, int index) { + const auto &operand = op.getOperand(index); + if (operand.name) + return operand.name->getAsUnquotedString(); + else + return formatv("{0}_{1}", generatedArgName, index); +} + namespace { // Simple RAII helper for defining ifdef-undef-endif scopes. class IfDefScope { @@ -114,122 +128,41 @@ public: private: OpEmitter(const Record &def, raw_ostream &os); - // Populates the operands and attributes. - void getOperandsAndAttributes(); - - // Returns the class name of the op. - StringRef cppClassName() const; - // Invokes the given function over all the namespaces of the class. - void mapOverClassNamespaces(function_ref fn) const; - - // Returns the operation name. - StringRef getOperationName() const; + void mapOverClassNamespaces(function_ref fn); // The record corresponding to the op. const Record &def; - const RecordKeeper &recordKeeper; - - // Record of Attr class. - Record *attrClass; - - // Type of DerivedAttr. - const RecordRecTy *derivedAttrType; - - // The name of the op split around '_'. - SmallVector splittedDefName; - - // The operands of the op. - SmallVector, 4> operands; - - // The attributes of the op. - SmallVector, 4> attrs; - SmallVector, 4> derivedAttrs; + // The operator being emitted. + Operator op; raw_ostream &os; }; } // end anonymous namespace OpEmitter::OpEmitter(const Record &def, raw_ostream &os) - : def(def), recordKeeper(def.getRecords()), - attrClass(recordKeeper.getClass("Attr")), - derivedAttrType(recordKeeper.getClass("DerivedAttr")->getType()), os(os) { - SplitString(def.getName(), splittedDefName, "_"); - getOperandsAndAttributes(); -} + : def(def), op(def), os(os) {} -StringRef OpEmitter::cppClassName() const { return splittedDefName.back(); } - -StringRef OpEmitter::getOperationName() const { - return def.getValueAsString("opName"); -} - -void OpEmitter::mapOverClassNamespaces(function_ref fn) const { +void OpEmitter::mapOverClassNamespaces(function_ref fn) { + auto &splittedDefName = op.getSplitDefName(); for (auto it = splittedDefName.begin(), e = std::prev(splittedDefName.end()); it != e; ++it) fn(*it); } -void OpEmitter::getOperandsAndAttributes() { - DagInit *argumentValues = def.getValueAsDag("arguments"); - for (unsigned i = 0, e = argumentValues->getNumArgs(); i != e; ++i) { - auto arg = argumentValues->getArg(i); - auto givenName = argumentValues->getArgName(i); - DefInit *argDef = dyn_cast(arg); - if (!argDef) - PrintFatalError(def.getLoc(), - "unexpected type for " + Twine(i) + "th argument"); - - // Handle attribute. - if (argDef->getDef()->isSubClassOf(attrClass)) { - if (!givenName) - PrintFatalError(argDef->getDef()->getLoc(), "attributes must be named"); - attrs.emplace_back(givenName->getValue(), argDef); - continue; - } - - // Handle operands. - std::string name; - if (givenName) - name = givenName->getValue(); - else - name = formatv("{0}_{1}", generatedArgName, i); - operands.emplace_back(name, argDef); - } - - // Derived attributes. - for (const auto &val : def.getValues()) { - if (auto *record = dyn_cast(val.getType())) { - if (record->typeIsA(derivedAttrType)) { - if (record->getClasses().size() != 1) { - PrintFatalError( - def.getLoc(), - "unsupported attribute modelling, only single class expected"); - } - derivedAttrs.emplace_back(&val, *record->getClasses().begin()); - continue; - } - if (record->isSubClassOf(attrClass)) - PrintFatalError(def.getLoc(), - "unexpected Attr where only DerivedAttr is allowed"); - } - } -} - void OpEmitter::emit(const Record &def, raw_ostream &os) { OpEmitter emitter(def, os); emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "\nnamespace " << ns << "{\n"; }); - os << "class " << emitter.cppClassName() << " : public Op<" - << emitter.cppClassName(); + os << formatv("class {0} : public Op<{0}", emitter.op.cppClassName()); emitter.emitTraits(); os << "> {\npublic:\n"; // Build operation name. os << " static StringRef getOperationName() { return \"" - << emitter.getOperationName() << "\"; };\n"; + << emitter.op.getOperationName() << "\"; };\n"; emitter.emitNamedOperands(); emitter.emitBuilder(); @@ -240,49 +173,34 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { emitter.emitCanonicalizationPatterns(); os << "private:\n friend class ::mlir::OperationInst;\n"; - os << " explicit " << emitter.cppClassName() - << "(const OperationInst* state) : Op(state) {}\n"; - os << "};\n"; + os << " explicit " << emitter.op.cppClassName() + << "(const OperationInst* state) : Op(state) {}\n};\n"; emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; }); } void OpEmitter::emitAttrGetters() { - for (const auto &pair : derivedAttrs) { - auto &val = *pair.first; + for (auto &attr : op.getAttributes()) { + auto name = getAttributeName(attr); + auto *def = attr.record; // Emit the derived attribute body. - if (auto defInit = dyn_cast(val.getValue())) { - if (defInit->getType()->typeIsA(derivedAttrType)) { - auto *def = defInit->getDef(); - os << " " << def->getValueAsString("returnType").trim() << ' ' - << val.getName() << "() const {" << def->getValueAsString("body") - << " }\n"; - continue; - } - } - } - - for (const auto &pair : attrs) { - auto &name = pair.first; - auto &attr = *pair.second->getDef(); - // Emit normal emitter. - if (!hasStringAttribute(attr, "storageType")) { - // Handle the base case where there is no storage type specified. - os << " Attribute " << name << "() const {\n return getAttr(\"" - << name << "\");\n }\n"; + if (attr.isDerived) { + os << " " << def->getValueAsString("returnType").trim() << ' ' << name + << "() const {" << def->getValueAsString("body") << " }\n"; continue; } - os << " " << attr.getValueAsString("returnType").trim() << ' ' << name + // Emit normal emitter. + os << " " << def->getValueAsString("returnType").trim() << ' ' << name << "() const {\n"; // Return the queried attribute with the correct return type. std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")", - attr.getValueAsString("storageType").trim(), name); + def->getValueAsString("storageType").trim(), name); os << " return " - << formatv(attr.getValueAsString("convertFromStorage"), attrVal) + << formatv(def->getValueAsString("convertFromStorage"), attrVal) << ";\n }\n"; } } @@ -295,10 +213,10 @@ void OpEmitter::emitNamedOperands() { return this->getInstruction()->getOperand({1}); } )"; - for (int i = 0, e = operands.size(); i != e; ++i) { - const auto &op = operands[i]; - if (!StringRef(op.first).startswith(generatedArgName)) - os << formatv(operandMethods, op.first, i); + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name) + os << formatv(operandMethods, operand.name->getAsUnquotedString(), i); } } @@ -328,15 +246,17 @@ void OpEmitter::emitBuilder() { os << ", Type returnType" << i; // Emit parameters for all operands - for (const auto &pair : operands) - os << ", Value* " << pair.first; + for (int i = 0, e = op.getNumOperands(); i != e; ++i) + os << ", Value* " << getArgumentName(op, i); // Emit parameters for all attributes // TODO(antiagainst): Support default initializer for attributes - for (const auto &pair : attrs) { - const Record &attr = *pair.second->getDef(); - os << ", " << getAsStringOrDefault(attr, "storageType", "Attribute").trim() - << ' ' << pair.first; + for (const auto &attr : op.getAttributes()) { + if (attr.isDerived) + break; + const Record &def = *attr.record; + os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim() + << ' ' << getAttributeName(attr); } os << ") {\n"; @@ -350,19 +270,18 @@ void OpEmitter::emitBuilder() { } // Push all operands to the result - if (!operands.empty()) { - os << " result->addOperands({" << operands.front().first; - for (auto it = operands.begin() + 1, e = operands.end(); it != e; ++it) - os << ", " << it->first; + if (op.getNumOperands() > 0) { + os << " result->addOperands({" << getArgumentName(op, 0); + for (int i = 1, e = op.getNumOperands(); i != e; ++i) + os << ", " << getArgumentName(op, i); os << "});\n"; } // Push all attributes to the result - for (const auto &pair : attrs) { - StringRef name = pair.first; - os << " result->addAttribute(\"" << name << "\", " << name << ");\n"; - } - + for (const auto &attr : op.getAttributes()) + if (!attr.isDerived) + os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n", + getAttributeName(attr)); os << " }\n"; // 2. Aggregated parameters @@ -378,16 +297,16 @@ void OpEmitter::emitBuilder() { << " result->addTypes(resultTypes);\n"; // Operands - os << " assert(args.size() == " << operands.size() + os << " assert(args.size() == " << op.getNumOperands() << "u && \"mismatched number of parameters\");\n" << " result->addOperands(args);\n\n"; // Attributes - if (attrs.empty()) { + if (op.getNumAttributes() > 0) { os << " assert(!attributes.size() && \"no attributes expected\");\n" << " }\n"; } else { - os << " assert(attributes.size() >= " << attrs.size() + os << " assert(attributes.size() >= " << op.getNumAttributes() << "u && \"not enough attributes\");\n" << " for (const auto& pair : attributes)\n" << " result->addAttribute(pair.first, pair.second);\n" @@ -424,14 +343,17 @@ void OpEmitter::emitVerifier() { auto valueInit = def.getValueInit("verifier"); CodeInit *codeInit = dyn_cast(valueInit); bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); - if (!hasCustomVerify && attrs.empty()) + if (!hasCustomVerify && op.getNumAttributes() == 0) return; os << " bool verify() const {\n"; // Verify the attributes have the correct type. - for (const auto attr : attrs) { - auto name = attr.first; - if (!hasStringAttribute(*attr.second->getDef(), "storageType")) { + for (const auto &attr : op.getAttributes()) { + if (attr.isDerived) + continue; + + auto name = getAttributeName(attr); + if (!hasStringAttribute(*attr.record, "storageType")) { os << " if (!this->getAttr(\"" << name << "\")) return emitOpError(\"requires attribute '" << name << "'\");\n"; @@ -439,10 +361,10 @@ void OpEmitter::emitVerifier() { } os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" - << attr.second->getDef()->getValueAsString("storageType").trim() + << attr.record->getValueAsString("storageType").trim() << ">()) return emitOpError(\"requires " - << attr.second->getDef()->getValueAsString("returnType").trim() - << " attribute '" << name << "'\");\n"; + << attr.record->getValueAsString("returnType").trim() << " attribute '" + << name << "'\");\n"; } if (hasCustomVerify) @@ -486,13 +408,13 @@ void OpEmitter::emitTraits() { } } - if ((hasVariadicOperands || hasAtLeastNOperands) && !operands.empty()) { + if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) { PrintFatalError(def.getLoc(), "Operands number definition is not consistent."); } // Add operand size trait if defined explicitly. - switch (operands.size()) { + switch (op.getNumOperands()) { case 0: if (!hasVariadicOperands && !hasAtLeastNOperands) os << ", OpTrait::ZeroOperands"; @@ -501,7 +423,7 @@ void OpEmitter::emitTraits() { os << ", OpTrait::OneOperand"; break; default: - os << ", OpTrait::NOperands<" << operands.size() << ">::Impl"; + os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl"; break; } @@ -517,8 +439,7 @@ void OpEmitter::emitTraits() { } // Emits the opcode enum and op classes. -static void emitOpClasses(const RecordKeeper &recordKeeper, - const std::vector &defs, raw_ostream &os) { +static void emitOpClasses(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_CLASSES", os); for (auto *def : defs) OpEmitter::emit(*def, os); @@ -532,10 +453,7 @@ static void emitOpList(const std::vector &defs, raw_ostream &os) { for (auto &def : defs) { if (!first) os << ","; - - SmallVector splittedDefName; - SplitString(def->getName(), splittedDefName, "_"); - os << join(splittedDefName, "::"); + os << Operator(def).qualifiedCppClassName(); first = false; } } @@ -546,7 +464,7 @@ static void emitOpDefinitions(const RecordKeeper &recordKeeper, const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); emitOpList(defs, os); - emitOpClasses(recordKeeper, defs, os); + emitOpClasses(defs, os); } static void emitOpDefFile(const RecordKeeper &recordKeeper, raw_ostream &os) {