diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index af77133cefde..90f729038697 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -338,6 +338,9 @@ single output. (e.g., bias add rule only matches the case where both Tensors have F32 elements). + 1. Attributes can be transformed by transform rules to produce an attribute + of a type different than the type matched. + TODO: Add constraints on the matching rules. TODO: Describe the generation of benefit metric given pattern. diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 3f3f17df71e9..4a9614d842b2 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -224,6 +224,9 @@ class Attr { // '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to // 'builder.getStringAttr("foo")'. code constBuilderCall = ?; + + // TODO(jpienaar): Add predicate to verify the validity of Attr so + // that verification can be generated. } // A generic attribute that must be constructed around a specific type. @@ -249,7 +252,12 @@ class IntegerAttrBase : TypeBasedAttr; def BoolAttr : Attr { let storageType = [{ BoolAttr }]; let returnType = [{ bool }]; - let constBuilderCall = [{ {0}.getBoolAttr({1})" }]; + let constBuilderCall = [{ {0}.getBoolAttr({1}) }]; +} +def ArrayAttr : Attr { + let storageType = [{ ArrayAttr }]; + let returnType = [{ ArrayAttr }]; + code convertFromStorage = "{0}"; } def ElementsAttr : Attr { let storageType = [{ ElementsAttr }]; @@ -407,4 +415,21 @@ class Pattern resultOps> { // Form of a pattern which produces a single result. class Pat : Pattern; +// Attribute matcher. This is the base class to specify a predicate +// that has to match. Used on the input attributes of a rewrite rule. +class mAttr { + // Code to match the attribute. + // Format: {0} represents the attribute. + CPred predicate = pred; +} + +// Attribute transforms. This is the base class to specify a +// transformation of a matched attribute. Used on the output of a rewrite +// rule. +class tAttr { + // Code to transform the attribute. + // Format: {0} represents the attribute. + code attrTransform = transform; +} + #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 4dc98596ea29..d6c550a0f7a7 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -40,7 +40,7 @@ namespace tblgen { class Attribute { public: explicit Attribute(const llvm::Record &def); - explicit Attribute(const llvm::Record *def) : Attribute(*def) {} + explicit Attribute(const llvm::Record *def); explicit Attribute(const llvm::DefInit *init); // Returns true if this attribute is a derived attribute (i.e., a subclass @@ -79,7 +79,7 @@ public: private: // The TableGen definition of this attribute. - const llvm::Record &def; + const llvm::Record *def; }; } // end namespace tblgen diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 59e45d7bfd77..46a784e831e3 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -35,25 +35,27 @@ static StringRef getValueAsString(const llvm::Init *init) { return {}; } -tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) { - assert(def.isSubClassOf("Attr") && +tblgen::Attribute::Attribute(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Attr") && "must be subclass of TableGen 'Attr' class"); } +tblgen::Attribute::Attribute(const llvm::Record &def) : Attribute(&def) {} + tblgen::Attribute::Attribute(const llvm::DefInit *init) : Attribute(*init->getDef()) {} bool tblgen::Attribute::isDerivedAttr() const { - return def.isSubClassOf("DerivedAttr"); + return def->isSubClassOf("DerivedAttr"); } bool tblgen::Attribute::hasStorageType() const { - const auto *init = def.getValueInit("storageType"); + const auto *init = def->getValueInit("storageType"); return !getValueAsString(init).empty(); } StringRef tblgen::Attribute::getStorageType() const { - const auto *init = def.getValueInit("storageType"); + const auto *init = def->getValueInit("storageType"); auto type = getValueAsString(init); if (type.empty()) return "Attribute"; @@ -61,30 +63,30 @@ StringRef tblgen::Attribute::getStorageType() const { } StringRef tblgen::Attribute::getReturnType() const { - const auto *init = def.getValueInit("returnType"); + const auto *init = def->getValueInit("returnType"); return getValueAsString(init); } StringRef tblgen::Attribute::getConvertFromStorageCall() const { - const auto *init = def.getValueInit("convertFromStorage"); + const auto *init = def->getValueInit("convertFromStorage"); return getValueAsString(init); } bool tblgen::Attribute::isConstBuildable() const { - const auto *init = def.getValueInit("constBuilderCall"); + const auto *init = def->getValueInit("constBuilderCall"); return !getValueAsString(init).empty(); } StringRef tblgen::Attribute::getConstBuilderTemplate() const { - const auto *init = def.getValueInit("constBuilderCall"); + const auto *init = def->getValueInit("constBuilderCall"); return getValueAsString(init); } StringRef tblgen::Attribute::getTableGenDefName() const { - return def.getName(); + return def->getName(); } StringRef tblgen::Attribute::getDerivedCodeBody() const { assert(isDerivedAttr() && "only derived attribute has 'body' field"); - return def.getValueAsString("body"); + return def->getValueAsString("body"); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index dbddd5a8102f..25d86393f8c6 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -190,7 +190,7 @@ void OpEmitter::emitAttrGetters() { OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n"; // Return the queried attribute with the correct return type. - std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")", + std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()", attr.getStorageType(), name); OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal) << ";\n }\n"; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index d8d7a391b151..b4839f5968a9 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -44,21 +44,19 @@ using mlir::tblgen::Type; namespace { -// Wrapper around dag argument. +// Wrapper around DAG argument. struct DagArg { - DagArg(Init *init) : init(init) {} + DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit) + : arg(arg), constraintInit(constraintInit) {} bool isAttr(); - Init *init; + mlir::tblgen::Operator::Argument arg; + Init *constraintInit; }; } // end namespace -bool DagArg::isAttr() { - if (auto defInit = dyn_cast(init)) - return defInit->getDef()->isSubClassOf("Attr"); - return false; -} +bool DagArg::isAttr() { return arg.is(); } namespace { class Pattern { @@ -80,9 +78,18 @@ private: // Collect bound arguments. void collectBoundArguments(DagInit *tree); + // Helper function to match patterns. + void matchOp(DagInit *tree, int depth); + + // Returns the Operator stored for the given record. + Operator &getOperator(const llvm::Record *record); + // Map from bound argument name to DagArg. StringMap boundArguments; + // Map from Record* to Operator. + DenseMap opMap; + // Number of the operations in the input pattern. int numberOfOpsMatched = 0; @@ -91,6 +98,11 @@ private: }; } // end namespace +// Returns the Operator stored for the given record. +auto Pattern::getOperator(const llvm::Record *record) -> Operator & { + return opMap.try_emplace(record, record).first->second; +} + void Pattern::emitAttributeValue(Record *constAttr) { Attribute attr(constAttr->getValueAsDef("attr")); auto value = constAttr->getValue("value"); @@ -107,6 +119,7 @@ void Pattern::emitAttributeValue(Record *constAttr) { void Pattern::collectBoundArguments(DagInit *tree) { ++numberOfOpsMatched; + Operator &op = getOperator(cast(tree->getOperator())->getDef()); // TODO(jpienaar): Expand to multiple matches. for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { auto arg = tree->getArg(i); @@ -117,14 +130,13 @@ void Pattern::collectBoundArguments(DagInit *tree) { auto name = tree->getArgNameStr(i); if (name.empty()) continue; - boundArguments.try_emplace(name, arg); + boundArguments.try_emplace(name, op.getArg(i), arg); } } // Helper function to match patterns. -static void matchOp(Record *pattern, DagInit *tree, int depth, - raw_ostream &os) { - Operator op(cast(tree->getOperator())->getDef()); +void Pattern::matchOp(DagInit *tree, int depth) { + Operator &op = getOperator(cast(tree->getOperator())->getDef()); int indent = 4 + 2 * depth; // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { @@ -148,7 +160,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, os.indent(indent + 2) << formatv( "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", depth + 1, depth, i); - matchOp(pattern, argTree, depth + 1, os); + matchOp(argTree, depth + 1); os.indent(indent) << "}\n"; continue; } @@ -174,7 +186,19 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, } // TODO(jpienaar): Verify attributes. - if (auto *attr = opArg.dyn_cast()) { + if (auto *namedAttr = opArg.dyn_cast()) { + // TODO(jpienaar): move to helper class. + if (defInit->getDef()->isSubClassOf("mAttr")) { + auto pred = + tblgen::Pred(defInit->getDef()->getValueInit("predicate")); + os.indent(indent) + << "if (!(" + << formatv(pred.getCondition().str().c_str(), + formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, + namedAttr->attr.getStorageType(), + namedAttr->getName())) + << ")) return matchFailure();\n"; + } } } @@ -202,7 +226,7 @@ void Pattern::emitMatcher(DagInit *tree) { if (op0->getNumResults() != 1) return matchFailure(); auto state = std::make_unique();)" << "\n"; - matchOp(pattern, tree, 0, os); + matchOp(tree, 0); os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; } @@ -224,9 +248,9 @@ void Pattern::emit(StringRef rewriteName) { // Emit matched state. os << " struct MatchedState : public PatternState {\n"; for (auto &arg : boundArguments) { - if (arg.second.isAttr()) { - DefInit *defInit = cast(arg.second.init); - os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first() + if (auto namedAttr = + arg.second.arg.dyn_cast()) { + os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() << ";\n"; } else { os.indent(4) << "Value* " << arg.first() << ";\n"; @@ -247,7 +271,7 @@ void Pattern::emit(StringRef rewriteName) { } DefInit *resultRoot = cast(resultTree->getOperator()); - Operator resultOp(*resultRoot->getDef()); + Operator &resultOp = getOperator(resultRoot->getDef()); auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); os << formatv(R"( @@ -296,8 +320,19 @@ void Pattern::emit(StringRef rewriteName) { if (boundArguments.find(name) == boundArguments.end()) PrintFatalError(pattern->getLoc(), Twine("referencing unbound variable '") + name + "'"); - os << "/*" << opName << "=*/" - << "s." << name; + auto result = "s." + name; + os << "/*" << opName << "=*/"; + if (defInit) { + auto transform = defInit->getDef(); + if (transform->isSubClassOf("tAttr")) { + // TODO(jpienaar): move to helper class. + os << formatv( + transform->getValueAsString("attrTransform").str().c_str(), + result); + continue; + } + } + os << result; continue; }