diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 523eb30e60a1..977090b777a0 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -55,6 +55,10 @@ public: // Operations attribute accessors. struct Attribute { + std::string getName() const; + StringRef getReturnType() const; + StringRef getStorageType() const; + llvm::StringInit *name; llvm::Record *record; bool isDerived; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 085d1db2cc8d..a126334207ff 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -141,6 +141,23 @@ void Operator::populateOperandsAndAttributes() { } } +std::string mlir::Operator::Attribute::getName() const { + std::string ret = name->getAsUnquotedString(); + // TODO(jpienaar): Revise this post dialect prefixing attribute discussion. + auto split = StringRef(ret).split("__"); + if (split.second.empty()) + return ret; + return llvm::join_items("$", split.first, split.second); +} + +StringRef mlir::Operator::Attribute::getReturnType() const { + return record->getValueAsString("returnType").trim(); +} + +StringRef mlir::Operator::Attribute::getStorageType() const { + return record->getValueAsString("storageType").trim(); +} + bool mlir::Operator::Operand::hasMatcher() const { llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue(); return !isa(matcher); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 99f69ccf1613..965f08d53304 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -34,6 +34,9 @@ using namespace mlir; static const char *const generatedArgName = "_arg"; +// Helper macro that returns indented os. +#define OUT(X) os.indent((X)) + // TODO(jpienaar): The builder body should probably be separate from the header. // Variation of method in FormatVariadic.h which takes a StringRef as input @@ -164,8 +167,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { os << "> {\npublic:\n"; // Build operation name. - os << " static StringRef getOperationName() { return \"" - << emitter.op.getOperationName() << "\"; };\n"; + OUT(2) << "static StringRef getOperationName() { return \"" + << emitter.op.getOperationName() << "\"; };\n"; emitter.emitNamedOperands(); emitter.emitBuilder(); @@ -176,8 +179,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { emitter.emitCanonicalizationPatterns(); emitter.emitConstantFolder(); - os << "private:\n friend class ::mlir::OperationInst;\n"; - os << " explicit " << emitter.op.cppClassName() + os << "private:\n friend class ::mlir::OperationInst;\n" + << " explicit " << emitter.op.cppClassName() << "(const OperationInst* state) : Op(state) {}\n};\n"; emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; }); @@ -190,22 +193,20 @@ void OpEmitter::emitAttrGetters() { // Emit the derived attribute body. if (attr.isDerived) { - os << " " << def->getValueAsString("returnType").trim() << ' ' << name - << "() const {" << def->getValueAsString("body") << " }\n"; + OUT(2) << attr.getReturnType() << ' ' << name << "() const {" + << def->getValueAsString("body") << " }\n"; continue; } // Emit normal emitter. - os << " " << def->getValueAsString("returnType").trim() << ' ' << name - << "() const {\n"; + OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n"; // Return the queried attribute with the correct return type. - std::string attrVal = - formatv("this->getAttrOfType<{0}>(\"{1}\")", - def->getValueAsString("storageType").trim(), name); - os << " return " - << formatv(def->getValueAsString("convertFromStorage"), attrVal) - << ";\n }\n"; + std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")", + attr.getStorageType(), name); + OUT(4) << "return " + << formatv(def->getValueAsString("convertFromStorage"), attrVal) + << ";\n }\n"; } } @@ -243,7 +244,7 @@ void OpEmitter::emitBuilder() { // 1. Stand-alone parameters std::vector returnTypes = def.getValueAsListOfDefs("returnTypes"); - os << " static void build(Builder* builder, OperationState* result"; + OUT(2) << "static void build(Builder* builder, OperationState* result"; // Emit parameters for all return types for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) @@ -267,7 +268,7 @@ void OpEmitter::emitBuilder() { // Push all result types to the result if (!returnTypes.empty()) { - os << " result->addTypes({returnType0"; + OUT(4) << "result->addTypes({returnType0"; for (unsigned i = 1, e = returnTypes.size(); i != e; ++i) os << ", returnType" << i; os << "});\n\n"; @@ -275,7 +276,7 @@ void OpEmitter::emitBuilder() { // Push all operands to the result if (op.getNumOperands() > 0) { - os << " result->addOperands({" << getArgumentName(op, 0); + OUT(4) << "result->addOperands({" << getArgumentName(op, 0); for (int i = 1, e = op.getNumOperands(); i != e; ++i) os << ", " << getArgumentName(op, i); os << "});\n"; @@ -284,45 +285,45 @@ void OpEmitter::emitBuilder() { // Push all attributes to the result for (const auto &attr : op.getAttributes()) if (!attr.isDerived) - os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n", - getAttributeName(attr)); - os << " }\n"; + OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n", + getAttributeName(attr)); + OUT(2) << "}\n"; // 2. Aggregated parameters // Signature - os << " static void build(Builder* builder, OperationState* result, " - << "ArrayRef resultTypes, ArrayRef args, " - "ArrayRef attributes) {\n"; + OUT(2) << "static void build(Builder* builder, OperationState* result, " + << "ArrayRef resultTypes, ArrayRef args, " + "ArrayRef attributes) {\n"; // Result types - os << " assert(resultTypes.size() == " << returnTypes.size() - << "u && \"mismatched number of return types\");\n" - << " result->addTypes(resultTypes);\n"; + OUT(4) << "assert(resultTypes.size() == " << returnTypes.size() + << "u && \"mismatched number of return types\");\n" + << " result->addTypes(resultTypes);\n"; // Operands - os << " assert(args.size() == " << op.getNumOperands() - << "u && \"mismatched number of parameters\");\n" - << " result->addOperands(args);\n\n"; + OUT(4) << "assert(args.size() == " << op.getNumOperands() + << "u && \"mismatched number of parameters\");\n" + << " result->addOperands(args);\n\n"; // Attributes if (op.getNumAttributes() > 0) { - os << " assert(!attributes.size() && \"no attributes expected\");\n" - << " }\n"; + OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n" + << " }\n"; } else { - os << " assert(attributes.size() >= " << op.getNumAttributes() - << "u && \"not enough attributes\");\n" - << " for (const auto& pair : attributes)\n" - << " result->addAttribute(pair.first, pair.second);\n" - << " }\n"; + OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes() + << "u && \"not enough attributes\");\n" + << " for (const auto& pair : attributes)\n" + << " result->addAttribute(pair.first, pair.second);\n" + << " }\n"; } } void OpEmitter::emitCanonicalizationPatterns() { if (!def.getValueAsBit("hasCanonicalizationPatterns")) return; - os << " static void getCanonicalizationPatterns(" - << "OwningRewritePatternList &results, MLIRContext* context);\n"; + OUT(2) << "static void getCanonicalizationPatterns(" + << "OwningRewritePatternList &results, MLIRContext* context);\n"; } void OpEmitter::emitConstantFolder() { @@ -363,7 +364,7 @@ void OpEmitter::emitVerifier() { if (!hasCustomVerify && op.getNumArgs() == 0) return; - os << " bool verify() const {\n"; + OUT(2) << "bool verify() const {\n"; // Verify the attributes have the correct type. for (const auto &attr : op.getAttributes()) { if (attr.isDerived) @@ -371,17 +372,15 @@ void OpEmitter::emitVerifier() { auto name = getAttributeName(attr); if (!hasStringAttribute(*attr.record, "storageType")) { - os << " if (!this->getAttr(\"" << name - << "\")) return emitOpError(\"requires attribute '" << name - << "'\");\n"; + OUT(4) << "if (!this->getAttr(\"" << name + << "\")) return emitOpError(\"requires attribute '" << name + << "'\");\n"; continue; } - os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" - << attr.record->getValueAsString("storageType").trim() - << ">()) return emitOpError(\"requires " - << attr.record->getValueAsString("returnType").trim() << " attribute '" - << name << "'\");\n"; + OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" + << attr.getStorageType() << ">()) return emitOpError(\"requires " + << attr.getReturnType() << " attribute '" << name << "'\");\n"; } // TODO: Handle variadic. @@ -392,17 +391,17 @@ void OpEmitter::emitVerifier() { if (operand.hasMatcher()) { auto pred = "if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n"; - os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" + - Twine(opIndex) + ")->getType()"); + OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" + + Twine(opIndex) + ")->getType()"); } ++opIndex; } if (hasCustomVerify) - os << " " << codeInit->getValue() << "\n"; + OUT(4) << codeInit->getValue() << "\n"; else - os << " return false;\n"; - os << " }\n"; + OUT(4) << "return false;\n"; + OUT(2) << "}\n"; } void OpEmitter::emitTraits() { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 6bc1366bd584..c83480d11301 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -150,6 +150,8 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, "' in pattern and op's definition"); for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { auto arg = tree->getArg(i); + auto opArg = op.getArg(i); + if (auto argTree = dyn_cast(arg)) { os.indent(indent) << "{\n"; os.indent(indent + 2) << formatv( @@ -162,12 +164,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, // Verify arguments. if (auto defInit = dyn_cast(arg)) { - auto opArg = op.getArg(i); // Verify operands. if (auto *operand = opArg.dyn_cast()) { // Skip verification where not needed due to definition of op. if (operand->defInit == defInit) - goto SkipOperandVerification; + goto StateCapture; if (!defInit->getDef()->isSubClassOf("Type")) PrintFatalError(pattern->getLoc(), @@ -185,15 +186,24 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, formatv("op{0}->getOperand({1})->getType()", depth, i)) << ")) return matchFailure();\n"; } - } - SkipOperandVerification: - // TODO(jpienaar): Verify attributes. + // TODO(jpienaar): Verify attributes. + if (auto *attr = opArg.dyn_cast()) { + } + } + + StateCapture: auto name = tree->getArgNameStr(i); if (name.empty()) continue; - os.indent(indent) << "state->" << name << " = op" << depth - << "->getOperand(" << i << ");\n"; + if (opArg.is()) + os.indent(indent) << "state->" << name << " = op" << depth + << "->getOperand(" << i << ");\n"; + if (auto attr = opArg.dyn_cast()) { + os.indent(indent) << "state->" << name << " = op" << depth + << "->getAttrOfType<" << attr->getStorageType() + << ">(\"" << attr->getName() << "\");\n"; + } } } @@ -291,13 +301,18 @@ void Pattern::emit(StringRef rewriteName) { (os << ",\n").indent(6); // The argument in the result DAG pattern. - auto name = resultOp.getArgName(i); + auto name = resultTree->getArgNameStr(i); + auto opName = resultOp.getArgName(i); auto defInit = dyn_cast(resultTree->getArg(i)); auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; - if (!value) - PrintFatalError(pattern->getLoc(), - Twine("attribute '") + name + - "' needs to be constant initialized"); + if (!value) { + if (boundArguments.find(name) == boundArguments.end()) + PrintFatalError(pattern->getLoc(), + Twine("referencing unbound variable '") + name + "'"); + os << "/*" << opName << "=*/" + << "s." << name; + continue; + } // TODO(jpienaar): Refactor out into map to avoid recomputing these. auto argument = resultOp.getArg(i);