From e0fc503896c3746824f466e567a2369156bc84a7 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 19 Feb 2019 07:15:59 -0800 Subject: [PATCH] [TableGen] Support using Variadic in results This CL extended TableGen Operator class to provide accessors for information on op results. In OpDefinitionGen, added checks to make sure only the last result can be variadic, and adjusted traits and builders generation to consider variadic results. PiperOrigin-RevId: 234596124 --- mlir/include/mlir/TableGen/Operator.h | 10 +++- mlir/lib/TableGen/Operator.cpp | 41 ++++++++++--- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 64 +++++++++++++++------ 3 files changed, 87 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index a4f5ffda3a94..ac5ceb5bdf97 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -72,6 +72,9 @@ public: // Returns the `index`-th result's name. StringRef getResultName(int index) const; + // Returns true if this operation has a variadic result. + bool hasVariadicResult() const; + // Op attribute interators. using attribute_iterator = const NamedAttribute *; attribute_iterator attribute_begin() const; @@ -123,8 +126,8 @@ public: StringRef getSummary() const; private: - // Populates the operands and attributes. - void populateOperandsAndAttributes(); + // Populates the vectors containing operands, attributes, and results. + void populateOpStructure(); // The name of the op split around '_'. SmallVector splittedDefName; @@ -135,6 +138,9 @@ private: // The attributes of the op. SmallVector attributes; + // The results of the op. + SmallVector results; + // The start of native attributes, which are specified when creating the op // as a part of the op's definition. int nativeAttrStart; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index c6e8c6c35bf2..db7f811fb58d 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -35,7 +35,7 @@ using llvm::Record; tblgen::Operator::Operator(const llvm::Record &def) : def(def) { SplitString(def.getName(), splittedDefName, "_"); - populateOperandsAndAttributes(); + populateOpStructure(); } const SmallVectorImpl &tblgen::Operator::getSplitDefName() const { @@ -63,13 +63,15 @@ int tblgen::Operator::getNumResults() const { } tblgen::Type tblgen::Operator::getResultType(int index) const { - DagInit *results = def.getValueAsDag("results"); - return Type(cast(results->getArg(index))); + return results[index].type; } StringRef tblgen::Operator::getResultName(int index) const { - DagInit *results = def.getValueAsDag("results"); - return results->getArgNameStr(index); + return results[index].name; +} + +bool tblgen::Operator::hasVariadicResult() const { + return !results.empty() && results.back().type.isVariadic(); } int tblgen::Operator::getNumNativeAttributes() const { @@ -127,7 +129,7 @@ auto tblgen::Operator::getArg(int index) -> Argument { return {&attributes[index - nativeAttrStart]}; } -void tblgen::Operator::populateOperandsAndAttributes() { +void tblgen::Operator::populateOpStructure() { auto &recordKeeper = def.getRecords(); auto attrClass = recordKeeper.getClass("Attr"); auto derivedAttrClass = recordKeeper.getClass("DerivedAttr"); @@ -144,7 +146,7 @@ void tblgen::Operator::populateOperandsAndAttributes() { auto argDefInit = dyn_cast(arg); if (!argDefInit) PrintFatalError(def.getLoc(), - Twine("undefined type for argument ") + Twine(i)); + Twine("undefined type for argument #") + Twine(i)); Record *argDef = argDefInit->getDef(); if (argDef->isSubClassOf(attrClass)) break; @@ -191,11 +193,36 @@ void tblgen::Operator::populateOperandsAndAttributes() { } } + // Verify that only the last operand can be variadic. for (int i = 0, e = operands.size() - 1; i < e; ++i) { if (operands[i].type.isVariadic()) PrintFatalError(def.getLoc(), "only the last operand allowed to be variadic"); } + + auto *resultsDag = def.getValueAsDag("results"); + auto *outsOp = dyn_cast(resultsDag->getOperator()); + if (!outsOp || outsOp->getDef()->getName() != "outs") { + PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); + } + + // Handle results. + for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { + auto name = resultsDag->getArgNameStr(i); + auto *resultDef = dyn_cast(resultsDag->getArg(i)); + if (!resultDef) { + PrintFatalError(def.getLoc(), + Twine("undefined type for result #") + Twine(i)); + } + results.push_back({name, Type(resultDef)}); + } + + // Verify that only the last result can be variadic. + for (int i = 0, e = results.size() - 1; i < e; ++i) { + if (results[i].type.isVariadic()) + PrintFatalError(def.getLoc(), + "only the last result allowed to be variadic"); + } } ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 1cf91cdeb0c8..786d784a3f4a 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -247,7 +247,8 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) { // Emit parameters for all return types if (!isAllSameType) { for (unsigned i = 0; i != numResults; ++i) - os << ", Type returnType" << i; + os << (op.getResultType(i).isVariadic() ? ", ArrayRef " : ", Type ") + << "returnType" << i; } // Emit parameters for all operands @@ -270,23 +271,36 @@ void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) { // Push all result types to the result if (numResults > 0) { - OUT(4) << "result->addTypes({"; if (!isAllSameType) { - os << "returnType0"; - for (unsigned i = 1; i != numResults; ++i) - os << ", returnType" << i; + bool hasVariadicResult = op.hasVariadicResult(); + int numNonVariadicResults = + numResults - static_cast(hasVariadicResult); + + if (numNonVariadicResults > 0) { + OUT(4) << "result->addTypes({returnType0"; + for (int i = 1; i < numNonVariadicResults; ++i) { + os << ", resultType" << i; + } + os << "});\n"; + } + + if (hasVariadicResult) { + OUT(4) << formatv("result->addTypes(returnType{0});\n", numResults - 1); + } } else { + OUT(4) << "result->addTypes({"; auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str(); os << resultType; for (unsigned i = 1; i != numResults; ++i) os << resultType; + os << "});\n\n"; } - os << "});\n\n"; } // Push all operands to the result bool hasVariadicOperand = op.hasVariadicOperand(); - int numNonVariadicOperands = numOperands - int(hasVariadicOperand); + int numNonVariadicOperands = + numOperands - static_cast(hasVariadicOperand); if (numNonVariadicOperands > 0) { OUT(4) << "result->addOperands({" << getArgumentName(op, 0); for (int i = 1; i < numNonVariadicOperands; ++i) { @@ -316,6 +330,8 @@ void OpEmitter::emitBuilder() { } auto numResults = op.getNumResults(); + bool hasVariadicResult = op.hasVariadicResult(); + int numNonVariadicResults = numResults - int(hasVariadicResult); auto numOperands = op.getNumOperands(); bool hasVariadicOperand = op.hasVariadicOperand(); @@ -345,7 +361,8 @@ void OpEmitter::emitBuilder() { "ArrayRef attributes) {\n"; // Result types - OUT(4) << "assert(resultTypes.size() == " << numResults + OUT(4) << "assert(resultTypes.size()" << (hasVariadicResult ? " >= " : " == ") + << numNonVariadicResults << "u && \"mismatched number of return types\");\n" << " result->addTypes(resultTypes);\n"; @@ -369,7 +386,7 @@ void OpEmitter::emitBuilder() { // 3. Deduced result types - if (op.hasTrait("SameOperandsAndResultType")) + if (!op.hasVariadicResult() && op.hasTrait("SameOperandsAndResultType")) emitStandaloneParamBuilder(/*isAllSameType=*/true); } @@ -501,18 +518,27 @@ void OpEmitter::emitVerifier() { void OpEmitter::emitTraits() { auto numResults = op.getNumResults(); + bool hasVariadicResult = op.hasVariadicResult(); // Add return size trait. - switch (numResults) { - case 0: - os << ", OpTrait::ZeroResult"; - break; - case 1: - os << ", OpTrait::OneResult"; - break; - default: - os << ", OpTrait::NResults<" << numResults << ">::Impl"; - break; + os << ", OpTrait::"; + if (hasVariadicResult) { + if (numResults == 1) + os << "VariadicResults"; + else + os << "AtLeastNResults<" << (numResults - 1) << ">::Impl"; + } else { + switch (numResults) { + case 0: + os << "ZeroResult"; + break; + case 1: + os << "OneResult"; + break; + default: + os << "NResults<" << numResults << ">::Impl"; + break; + } } // Add variadic size trait and normal op traits.