From aae85ddce104fd5dca4dbb1b97a8b9949b249c93 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 7 Jan 2019 09:52:26 -0800 Subject: [PATCH] Match attributes in input pattern. Bind attributes similar to operands. Use to rewrite leakyreulo and const rewrite pattern. The attribute type/attributes are not currently checked so should only be used where the attributes match due to the construction of the op. To support current attribute namespacing, convert __ in attribute name to "$" for matching purposes ('$' is not valid character in variable in TableGen). Some simplification to make it simpler to specify indented ostream and avoid so many spaces. The goal is not to have perfectly formatted code generated but good enough so that its still easy to read for a user. PiperOrigin-RevId: 228183639 --- mlir/include/mlir/TableGen/Operator.h | 4 + mlir/lib/TableGen/Operator.cpp | 17 ++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 103 ++++++++++---------- mlir/tools/mlir-tblgen/RewriterGen.cpp | 39 +++++--- 4 files changed, 99 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 523eb30e60a1..977090b777a0 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -55,6 +55,10 @@ public: // Operations attribute accessors. struct Attribute { + std::string getName() const; + StringRef getReturnType() const; + StringRef getStorageType() const; + llvm::StringInit *name; llvm::Record *record; bool isDerived; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 085d1db2cc8d..a126334207ff 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -141,6 +141,23 @@ void Operator::populateOperandsAndAttributes() { } } +std::string mlir::Operator::Attribute::getName() const { + std::string ret = name->getAsUnquotedString(); + // TODO(jpienaar): Revise this post dialect prefixing attribute discussion. + auto split = StringRef(ret).split("__"); + if (split.second.empty()) + return ret; + return llvm::join_items("$", split.first, split.second); +} + +StringRef mlir::Operator::Attribute::getReturnType() const { + return record->getValueAsString("returnType").trim(); +} + +StringRef mlir::Operator::Attribute::getStorageType() const { + return record->getValueAsString("storageType").trim(); +} + bool mlir::Operator::Operand::hasMatcher() const { llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue(); return !isa(matcher); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 99f69ccf1613..965f08d53304 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -34,6 +34,9 @@ using namespace mlir; static const char *const generatedArgName = "_arg"; +// Helper macro that returns indented os. +#define OUT(X) os.indent((X)) + // TODO(jpienaar): The builder body should probably be separate from the header. // Variation of method in FormatVariadic.h which takes a StringRef as input @@ -164,8 +167,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { os << "> {\npublic:\n"; // Build operation name. - os << " static StringRef getOperationName() { return \"" - << emitter.op.getOperationName() << "\"; };\n"; + OUT(2) << "static StringRef getOperationName() { return \"" + << emitter.op.getOperationName() << "\"; };\n"; emitter.emitNamedOperands(); emitter.emitBuilder(); @@ -176,8 +179,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { emitter.emitCanonicalizationPatterns(); emitter.emitConstantFolder(); - os << "private:\n friend class ::mlir::OperationInst;\n"; - os << " explicit " << emitter.op.cppClassName() + os << "private:\n friend class ::mlir::OperationInst;\n" + << " explicit " << emitter.op.cppClassName() << "(const OperationInst* state) : Op(state) {}\n};\n"; emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; }); @@ -190,22 +193,20 @@ void OpEmitter::emitAttrGetters() { // Emit the derived attribute body. if (attr.isDerived) { - os << " " << def->getValueAsString("returnType").trim() << ' ' << name - << "() const {" << def->getValueAsString("body") << " }\n"; + OUT(2) << attr.getReturnType() << ' ' << name << "() const {" + << def->getValueAsString("body") << " }\n"; continue; } // Emit normal emitter. - os << " " << def->getValueAsString("returnType").trim() << ' ' << name - << "() const {\n"; + OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n"; // Return the queried attribute with the correct return type. - std::string attrVal = - formatv("this->getAttrOfType<{0}>(\"{1}\")", - def->getValueAsString("storageType").trim(), name); - os << " return " - << formatv(def->getValueAsString("convertFromStorage"), attrVal) - << ";\n }\n"; + std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")", + attr.getStorageType(), name); + OUT(4) << "return " + << formatv(def->getValueAsString("convertFromStorage"), attrVal) + << ";\n }\n"; } } @@ -243,7 +244,7 @@ void OpEmitter::emitBuilder() { // 1. Stand-alone parameters std::vector returnTypes = def.getValueAsListOfDefs("returnTypes"); - os << " static void build(Builder* builder, OperationState* result"; + OUT(2) << "static void build(Builder* builder, OperationState* result"; // Emit parameters for all return types for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) @@ -267,7 +268,7 @@ void OpEmitter::emitBuilder() { // Push all result types to the result if (!returnTypes.empty()) { - os << " result->addTypes({returnType0"; + OUT(4) << "result->addTypes({returnType0"; for (unsigned i = 1, e = returnTypes.size(); i != e; ++i) os << ", returnType" << i; os << "});\n\n"; @@ -275,7 +276,7 @@ void OpEmitter::emitBuilder() { // Push all operands to the result if (op.getNumOperands() > 0) { - os << " result->addOperands({" << getArgumentName(op, 0); + OUT(4) << "result->addOperands({" << getArgumentName(op, 0); for (int i = 1, e = op.getNumOperands(); i != e; ++i) os << ", " << getArgumentName(op, i); os << "});\n"; @@ -284,45 +285,45 @@ void OpEmitter::emitBuilder() { // Push all attributes to the result for (const auto &attr : op.getAttributes()) if (!attr.isDerived) - os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n", - getAttributeName(attr)); - os << " }\n"; + OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n", + getAttributeName(attr)); + OUT(2) << "}\n"; // 2. Aggregated parameters // Signature - os << " static void build(Builder* builder, OperationState* result, " - << "ArrayRef resultTypes, ArrayRef args, " - "ArrayRef attributes) {\n"; + OUT(2) << "static void build(Builder* builder, OperationState* result, " + << "ArrayRef resultTypes, ArrayRef args, " + "ArrayRef attributes) {\n"; // Result types - os << " assert(resultTypes.size() == " << returnTypes.size() - << "u && \"mismatched number of return types\");\n" - << " result->addTypes(resultTypes);\n"; + OUT(4) << "assert(resultTypes.size() == " << returnTypes.size() + << "u && \"mismatched number of return types\");\n" + << " result->addTypes(resultTypes);\n"; // Operands - os << " assert(args.size() == " << op.getNumOperands() - << "u && \"mismatched number of parameters\");\n" - << " result->addOperands(args);\n\n"; + OUT(4) << "assert(args.size() == " << op.getNumOperands() + << "u && \"mismatched number of parameters\");\n" + << " result->addOperands(args);\n\n"; // Attributes if (op.getNumAttributes() > 0) { - os << " assert(!attributes.size() && \"no attributes expected\");\n" - << " }\n"; + OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n" + << " }\n"; } else { - os << " assert(attributes.size() >= " << op.getNumAttributes() - << "u && \"not enough attributes\");\n" - << " for (const auto& pair : attributes)\n" - << " result->addAttribute(pair.first, pair.second);\n" - << " }\n"; + OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes() + << "u && \"not enough attributes\");\n" + << " for (const auto& pair : attributes)\n" + << " result->addAttribute(pair.first, pair.second);\n" + << " }\n"; } } void OpEmitter::emitCanonicalizationPatterns() { if (!def.getValueAsBit("hasCanonicalizationPatterns")) return; - os << " static void getCanonicalizationPatterns(" - << "OwningRewritePatternList &results, MLIRContext* context);\n"; + OUT(2) << "static void getCanonicalizationPatterns(" + << "OwningRewritePatternList &results, MLIRContext* context);\n"; } void OpEmitter::emitConstantFolder() { @@ -363,7 +364,7 @@ void OpEmitter::emitVerifier() { if (!hasCustomVerify && op.getNumArgs() == 0) return; - os << " bool verify() const {\n"; + OUT(2) << "bool verify() const {\n"; // Verify the attributes have the correct type. for (const auto &attr : op.getAttributes()) { if (attr.isDerived) @@ -371,17 +372,15 @@ void OpEmitter::emitVerifier() { auto name = getAttributeName(attr); if (!hasStringAttribute(*attr.record, "storageType")) { - os << " if (!this->getAttr(\"" << name - << "\")) return emitOpError(\"requires attribute '" << name - << "'\");\n"; + OUT(4) << "if (!this->getAttr(\"" << name + << "\")) return emitOpError(\"requires attribute '" << name + << "'\");\n"; continue; } - os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" - << attr.record->getValueAsString("storageType").trim() - << ">()) return emitOpError(\"requires " - << attr.record->getValueAsString("returnType").trim() << " attribute '" - << name << "'\");\n"; + OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" + << attr.getStorageType() << ">()) return emitOpError(\"requires " + << attr.getReturnType() << " attribute '" << name << "'\");\n"; } // TODO: Handle variadic. @@ -392,17 +391,17 @@ void OpEmitter::emitVerifier() { if (operand.hasMatcher()) { auto pred = "if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n"; - os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" + - Twine(opIndex) + ")->getType()"); + OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" + + Twine(opIndex) + ")->getType()"); } ++opIndex; } if (hasCustomVerify) - os << " " << codeInit->getValue() << "\n"; + OUT(4) << codeInit->getValue() << "\n"; else - os << " return false;\n"; - os << " }\n"; + OUT(4) << "return false;\n"; + OUT(2) << "}\n"; } void OpEmitter::emitTraits() { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 6bc1366bd584..c83480d11301 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -150,6 +150,8 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, "' in pattern and op's definition"); for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { auto arg = tree->getArg(i); + auto opArg = op.getArg(i); + if (auto argTree = dyn_cast(arg)) { os.indent(indent) << "{\n"; os.indent(indent + 2) << formatv( @@ -162,12 +164,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, // Verify arguments. if (auto defInit = dyn_cast(arg)) { - auto opArg = op.getArg(i); // Verify operands. if (auto *operand = opArg.dyn_cast()) { // Skip verification where not needed due to definition of op. if (operand->defInit == defInit) - goto SkipOperandVerification; + goto StateCapture; if (!defInit->getDef()->isSubClassOf("Type")) PrintFatalError(pattern->getLoc(), @@ -185,15 +186,24 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, formatv("op{0}->getOperand({1})->getType()", depth, i)) << ")) return matchFailure();\n"; } - } - SkipOperandVerification: - // TODO(jpienaar): Verify attributes. + // TODO(jpienaar): Verify attributes. + if (auto *attr = opArg.dyn_cast()) { + } + } + + StateCapture: auto name = tree->getArgNameStr(i); if (name.empty()) continue; - os.indent(indent) << "state->" << name << " = op" << depth - << "->getOperand(" << i << ");\n"; + if (opArg.is()) + os.indent(indent) << "state->" << name << " = op" << depth + << "->getOperand(" << i << ");\n"; + if (auto attr = opArg.dyn_cast()) { + os.indent(indent) << "state->" << name << " = op" << depth + << "->getAttrOfType<" << attr->getStorageType() + << ">(\"" << attr->getName() << "\");\n"; + } } } @@ -291,13 +301,18 @@ void Pattern::emit(StringRef rewriteName) { (os << ",\n").indent(6); // The argument in the result DAG pattern. - auto name = resultOp.getArgName(i); + auto name = resultTree->getArgNameStr(i); + auto opName = resultOp.getArgName(i); auto defInit = dyn_cast(resultTree->getArg(i)); auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; - if (!value) - PrintFatalError(pattern->getLoc(), - Twine("attribute '") + name + - "' needs to be constant initialized"); + if (!value) { + if (boundArguments.find(name) == boundArguments.end()) + PrintFatalError(pattern->getLoc(), + Twine("referencing unbound variable '") + name + "'"); + os << "/*" << opName << "=*/" + << "s." << name; + continue; + } // TODO(jpienaar): Refactor out into map to avoid recomputing these. auto argument = resultOp.getArg(i);