diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 5b186ddb155c..58e6867f547f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -495,6 +495,10 @@ class DefaultValuedAttr : let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = val; + + // Remember `attr`'s def name. + // TOOD(b/132458159): consider embedding Attr as a field. + string baseAttr = !cast(attr); } // Decorates an attribute as optional. The return type of the generated @@ -507,6 +511,10 @@ class OptionalAttr : Attr { let convertFromStorage = "$_self ? " # returnType # "(" # attr.convertFromStorage # ") : (llvm::None)"; let isOptional = 0b1; + + // Remember `attr`'s def name. + // TOOD(b/132458159): consider embedding Attr as a field. + string baseAttr = !cast(attr); } // A generic attribute that must be constructed around a specific type diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index f8de36d753ec..f9216055d20b 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -92,7 +92,10 @@ public: // Returns whether this attribute is optional. bool isOptional() const; - StringRef getTableGenDefName() const; + // Returns this attribute's TableGen def name. If this is an `OptionalAttr` + // or `DefaultValuedAttr` without explicit name, returns the base attribute's + // name. + StringRef getAttrDefName() const; // Returns the code body for derived attribute. Aborts if this is not a // derived attribute. diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index a165ba8be4fb..6e4083c6e58e 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -110,7 +110,9 @@ bool tblgen::Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } -StringRef tblgen::Attribute::getTableGenDefName() const { +StringRef tblgen::Attribute::getAttrDefName() const { + if (def->isAnonymous() && (isOptional() || hasDefaultValueInitializer())) + return getValueAsString(def->getValueInit("baseAttr")); return def->getName(); } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 26a8f8e5b499..9cf85079ec91 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -234,7 +234,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, std::string PatternEmitter::handleConstantAttr(Attribute attr, StringRef value) { if (!attr.isConstBuildable()) - PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + + PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + " does not have the 'constBuilderCall' field"); // TODO(jpienaar): Verify the constants here