Enable useDefault{Type/Attribute}PrinterParser by default in ODS Dialect definition

The majority of dialects reimplement the same boilerplate over and over,
switching the default makes it for better discoverability and make it simpler
to implement new dialects.

Differential Revision: https://reviews.llvm.org/D117524
This commit is contained in:
Mehdi Amini 2022-01-18 06:33:21 +00:00
parent ade71641dc
commit c8e047f5e1
12 changed files with 30 additions and 105 deletions

View File

@ -19,9 +19,10 @@ include "mlir/IR/OpBase.td"
def Builtin_Dialect : Dialect { def Builtin_Dialect : Dialect {
let summary = let summary =
"A dialect containing the builtin Attributes, Operations, and Types"; "A dialect containing the builtin Attributes, Operations, and Types";
let name = "builtin"; let name = "builtin";
let cppNamespace = "::mlir"; let cppNamespace = "::mlir";
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
let extraClassDeclaration = [{ let extraClassDeclaration = [{
private: private:
// Register the builtin Attributes. // Register the builtin Attributes.

View File

@ -314,11 +314,11 @@ class Dialect {
// If this dialect should use default generated attribute parser boilerplate: // If this dialect should use default generated attribute parser boilerplate:
// it'll dispatch the parsing to every individual attributes directly. // it'll dispatch the parsing to every individual attributes directly.
bit useDefaultAttributePrinterParser = 0; bit useDefaultAttributePrinterParser = 1;
// If this dialect should use default generated type parser boilerplate: // If this dialect should use default generated type parser boilerplate:
// it'll dispatch the parsing to every individual types directly. // it'll dispatch the parsing to every individual types directly.
bit useDefaultTypePrinterParser = 0; bit useDefaultTypePrinterParser = 1;
// If this dialect overrides the hook for canonicalization patterns. // If this dialect overrides the hook for canonicalization patterns.
bit hasCanonicalizer = 0; bit hasCanonicalizer = 0;

View File

@ -346,22 +346,3 @@ Type ValueType::parse(mlir::AsmParser &parser) {
} }
return ValueType::get(ty); return ValueType::get(ty);
} }
/// Print a type registered to this dialect.
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'async' type kind");
}
/// Parse a type registered to this dialect.
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
StringRef typeTag;
if (parser.parseKeyword(&typeTag))
return Type();
Type genType;
auto parseResult = generatedTypeParser(parser, typeTag, genType);
if (parseResult.hasValue())
return genType;
parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
return {};
}

View File

@ -180,26 +180,6 @@ Attribute emitc::OpaqueAttr::parse(AsmParser &parser, Type type) {
return get(parser.getContext(), value); return get(parser.getContext(), value);
} }
Attribute EmitCDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
StringRef mnemonic;
if (parser.parseKeyword(&mnemonic))
return Attribute();
Attribute genAttr;
OptionalParseResult parseResult =
generatedAttributeParser(parser, mnemonic, type, genAttr);
if (parseResult.hasValue())
return genAttr;
parser.emitError(typeLoc, "unknown attribute in EmitC dialect");
return Attribute();
}
void EmitCDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
if (failed(generatedAttributePrinter(attr, os)))
llvm_unreachable("unexpected 'EmitC' attribute kind");
}
void emitc::OpaqueAttr::print(AsmPrinter &printer) const { void emitc::OpaqueAttr::print(AsmPrinter &printer) const {
printer << "<\""; printer << "<\"";
llvm::printEscapedString(getValue(), printer.getStream()); llvm::printEscapedString(getValue(), printer.getStream());

View File

@ -2906,28 +2906,3 @@ Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
llvm::sort(options, llvm::less_first()); llvm::sort(options, llvm::less_first());
return get(parser.getContext(), options); return get(parser.getContext(), options);
} }
Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
if (type) {
parser.emitError(parser.getNameLoc(), "unexpected type");
return {};
}
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
{
Attribute attr;
auto parseResult = generatedAttributeParser(parser, attrKind, type, attr);
if (parseResult.hasValue())
return attr;
}
parser.emitError(parser.getNameLoc(), "unknown attribute type: ") << attrKind;
return {};
}
void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
if (succeeded(generatedAttributePrinter(attr, os)))
return;
llvm_unreachable("Unknown attribute type");
}

View File

@ -53,15 +53,6 @@ static Type parsePDLType(AsmParser &parser) {
return Type(); return Type();
} }
Type PDLDialect::parseType(DialectAsmParser &parser) const {
return parsePDLType(parser);
}
void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
if (failed(generatedTypePrinter(type, printer)))
llvm_unreachable("unknown 'pdl' type");
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PDL Types // PDL Types
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -346,22 +346,3 @@ void SparseTensorDialect::initialize() {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrTag;
if (failed(parser.parseKeyword(&attrTag)))
return Attribute();
Attribute attr;
auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
if (parseResult.hasValue())
return attr;
parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
return Attribute();
}
void SparseTensorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &printer) const {
if (succeeded(generatedAttributePrinter(attr, printer)))
return;
}

View File

@ -22,7 +22,7 @@ def Test_Dialect : Dialect {
let hasRegionResultAttrVerify = 1; let hasRegionResultAttrVerify = 1;
let hasOperationInterfaceFallback = 1; let hasOperationInterfaceFallback = 1;
let hasNonDefaultDestructor = 1; let hasNonDefaultDestructor = 1;
let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 0;
let dependentDialects = ["::mlir::DLTIDialect"]; let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
@ -36,6 +36,9 @@ def Test_Dialect : Dialect {
::mlir::OpAsmPrinter &printer)> ::mlir::OpAsmPrinter &printer)>
getOperationPrinter(::mlir::Operation *op) const override; getOperationPrinter(::mlir::Operation *op) const override;
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;
private: private:
// Storage for a custom fallback interface. // Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces; void *fallbackEffectOpInterfaces;

View File

@ -7,6 +7,7 @@ include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect { def Test_Dialect : Dialect {
let name = "TestDialect"; let name = "TestDialect";
let cppNamespace = "::test"; let cppNamespace = "::test";
let useDefaultTypePrinterParser = 0;
} }
class TestAttr<string name> : AttrDef<Test_Dialect, name>; class TestAttr<string name> : AttrDef<Test_Dialect, name>;

View File

@ -34,7 +34,6 @@ include "mlir/IR/OpBase.td"
def Test_Dialect: Dialect { def Test_Dialect: Dialect {
// DECL-NOT: TestDialect // DECL-NOT: TestDialect
// DEF-NOT: TestDialect
let name = "TestDialect"; let name = "TestDialect";
let cppNamespace = "::test"; let cppNamespace = "::test";
} }

View File

@ -616,9 +616,11 @@ public:
protected: protected:
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os, DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator) StringRef defType, StringRef valueType, bool isAttrGenerator,
bool needsDialectParserPrinter)
: defRecords(std::move(defs)), os(os), defType(defType), : defRecords(std::move(defs)), os(os), defType(defType),
valueType(valueType), isAttrGenerator(isAttrGenerator) {} valueType(valueType), isAttrGenerator(isAttrGenerator),
needsDialectParserPrinter(needsDialectParserPrinter) {}
/// Emit the list of def type names. /// Emit the list of def type names.
void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs); void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
@ -637,19 +639,29 @@ protected:
/// Flag indicating if this generator is for Attributes. False if the /// Flag indicating if this generator is for Attributes. False if the
/// generator is for types. /// generator is for types.
bool isAttrGenerator; bool isAttrGenerator;
/// Track if we need to emit the printAttribute/parseAttribute
/// implementations.
bool needsDialectParserPrinter;
}; };
/// A specialized generator for AttrDefs. /// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator { struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os, "Attr", : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os, "Attr",
"Attribute", /*isAttrGenerator=*/true) {} "Attribute",
/*isAttrGenerator=*/true,
/*needsDialectParserPrinter=*/
!records.getAllDerivedDefinitions("DialectAttr").empty()) {
}
}; };
/// A specialized generator for TypeDefs. /// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator { struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os, "Type", : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os, "Type",
"Type", /*isAttrGenerator=*/false) {} "Type", /*isAttrGenerator=*/false,
/*needsDialectParserPrinter=*/
!records.getAllDerivedDefinitions("DialectType").empty()) {
}
}; };
} // namespace } // namespace
@ -860,7 +872,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
Dialect firstDialect = defs.front().getDialect(); Dialect firstDialect = defs.front().getDialect();
// Emit the default parser/printer for Attributes if the dialect asked for // Emit the default parser/printer for Attributes if the dialect asked for
// it. // it.
if (valueType == "Attribute" && if (valueType == "Attribute" && needsDialectParserPrinter &&
firstDialect.useDefaultAttributePrinterParser()) { firstDialect.useDefaultAttributePrinterParser()) {
NamespaceEmitter nsEmitter(os, firstDialect); NamespaceEmitter nsEmitter(os, firstDialect);
os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
@ -868,7 +880,8 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
} }
// Emit the default parser/printer for Types if the dialect asked for it. // Emit the default parser/printer for Types if the dialect asked for it.
if (valueType == "Type" && firstDialect.useDefaultTypePrinterParser()) { if (valueType == "Type" && needsDialectParserPrinter &&
firstDialect.useDefaultTypePrinterParser()) {
NamespaceEmitter nsEmitter(os, firstDialect); NamespaceEmitter nsEmitter(os, firstDialect);
os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
firstDialect.getCppClassName()); firstDialect.getCppClassName());

View File

@ -210,9 +210,9 @@ emitDialectDecl(Dialect &dialect,
// Check for any attributes/types registered to this dialect. If there are, // Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing. // add the hooks for parsing/printing.
if (!dialectAttrs.empty() || dialect.useDefaultAttributePrinterParser()) if (!dialectAttrs.empty() && dialect.useDefaultAttributePrinterParser())
os << attrParserDecl; os << attrParserDecl;
if (!dialectTypes.empty() || dialect.useDefaultTypePrinterParser()) if (!dialectTypes.empty() && dialect.useDefaultTypePrinterParser())
os << typeParserDecl; os << typeParserDecl;
// Add the decls for the various features of the dialect. // Add the decls for the various features of the dialect.