diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md index 2aadbb6035d0..340940f38547 100644 --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1225,7 +1225,7 @@ was imported: - Imported `Type` constraints utilize the `cppClassName` field for native type translation. * `AttrInterface`/`OpInterface`/`TypeInterface` constraints - - Imported interfaces utilize the `cppClassName` field for native type translation. + - Imported interfaces utilize the `cppInterfaceName` field for native type translation. #### Defining Constraints Inline diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 89c7122b5ba7..16174454b7a1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1949,7 +1949,7 @@ class Interface { string description = ""; // The name given to the c++ interface class. - string cppClassName = name; + string cppInterfaceName = name; // The C++ namespace that this interface should be placed into. // @@ -1970,13 +1970,25 @@ class Interface { } // AttrInterface represents an interface registered to an attribute. -class AttrInterface : Interface, InterfaceTrait; +class AttrInterface : Interface, InterfaceTrait, + Attr()">, + name # " instance"> +{ + let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name; + let returnType = storageType; + let convertFromStorage = "$_self"; +} // OpInterface represents an interface registered to an operation. class OpInterface : Interface, OpInterfaceTrait; // TypeInterface represents an interface registered to a type. -class TypeInterface : Interface, InterfaceTrait; +class TypeInterface : Interface, InterfaceTrait, + Type()">, + name # " instance", + !if(!empty(cppNamespace),"", cppNamespace # "::") # name>; // Whether to declare the interface methods in the user entity's header. This // class simply wraps an Interface but is used to indicate that the method @@ -1992,27 +2004,27 @@ class DeclareInterfaceMethods overridenMethods = []> { class DeclareAttrInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - AttrInterface { + AttrInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareOpInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - OpInterface { + OpInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareTypeInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - TypeInterface { + TypeInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index 4d72ceeb45fc..1ee0b140756f 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -81,7 +81,7 @@ Interface::Interface(const llvm::Record *def) : def(def) { // Return the name of this interface. StringRef Interface::getName() const { - return def->getValueAsString("cppClassName"); + return def->getValueAsString("cppInterfaceName"); } // Return the C++ namespace of this interface. diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 4b7fd85227aa..55b1e3947f3b 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -873,38 +873,43 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, addTypeConstraint(result)); } } + + auto shouldBeSkipped = [this](llvm::Record *def) { + return def->isAnonymous() || curDeclScope->lookup(def->getName()) || + def->isSubClassOf("DeclareInterfaceMethods"); + }; + /// Attr constraints. for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { - if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { - tblgen::Attribute constraint(def); - decls.push_back( - createODSNativePDLLConstraintDecl( - constraint, convertLocToRange(def->getLoc().front()), attrTy, - constraint.getStorageType())); - } + if (shouldBeSkipped(def)) + continue; + + tblgen::Attribute constraint(def); + decls.push_back(createODSNativePDLLConstraintDecl( + constraint, convertLocToRange(def->getLoc().front()), attrTy, + constraint.getStorageType())); } /// Type constraints. for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { - if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { - tblgen::TypeConstraint constraint(def); - decls.push_back( - createODSNativePDLLConstraintDecl( - constraint, convertLocToRange(def->getLoc().front()), typeTy, - constraint.getCPPClassName())); - } - } - /// Interfaces. - ast::Type opTy = ast::OperationType::get(ctx); - for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { - StringRef name = def->getName(); - if (def->isAnonymous() || curDeclScope->lookup(name) || - def->isSubClassOf("DeclareInterfaceMethods")) + if (shouldBeSkipped(def)) continue; + + tblgen::TypeConstraint constraint(def); + decls.push_back(createODSNativePDLLConstraintDecl( + constraint, convertLocToRange(def->getLoc().front()), typeTy, + constraint.getCPPClassName())); + } + /// OpInterfaces. + ast::Type opTy = ast::OperationType::get(ctx); + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) { + if (shouldBeSkipped(def)) + continue; + SMRange loc = convertLocToRange(def->getLoc().front()); std::string cppClassName = llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"), - def->getValueAsString("cppClassName")) + def->getValueAsString("cppInterfaceName")) .str(); std::string codeBlock = llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));", @@ -913,18 +918,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, std::string desc = processAndFormatDoc(def->getValueAsString("description")); - if (def->isSubClassOf("OpInterface")) { - decls.push_back(createODSNativePDLLConstraintDecl( - name, codeBlock, loc, opTy, cppClassName, desc)); - } else if (def->isSubClassOf("AttrInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, attrTy, cppClassName, desc)); - } else if (def->isSubClassOf("TypeInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, typeTy, cppClassName, desc)); - } + decls.push_back(createODSNativePDLLConstraintDecl( + def->getName(), codeBlock, loc, opTy, cppClassName, desc)); } } diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll index f90f7ab8a412..5526aa852482 100644 --- a/mlir/test/mlir-pdll/Parser/include_td.pdll +++ b/mlir/test/mlir-pdll/Parser/include_td.pdll @@ -32,21 +32,21 @@ // CHECK-NEXT: CppClass: ::mlir::IntegerType // CHECK-NEXT: } -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-AttrConstraintDecl +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-TypeConstraintDecl {{.*}} + // CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-OpConstraintDecl // CHECK: `-OpNameDecl - -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> -// CHECK: `Inputs` -// CHECK: `-VariableDecl {{.*}} Name Type -// CHECK: `Constraints` -// CHECK: `-TypeConstraintDecl {{.*}} diff --git a/mlir/test/mlir-tblgen/interfaces-as-constraints.td b/mlir/test/mlir-tblgen/interfaces-as-constraints.td new file mode 100644 index 000000000000..5963dd8bb8ac --- /dev/null +++ b/mlir/test/mlir-tblgen/interfaces-as-constraints.td @@ -0,0 +1,47 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def TopLevelTypeInterface : TypeInterface<"TopLevelTypeInterface">; + +def TypeInterfaceInNamespace : TypeInterface<"TypeInterfaceInNamespace"> { + let cppNamespace = "test"; +} + +def TopLevelAttrInterface : AttrInterface<"TopLevelAttrInterface">; + +def AttrInterfaceInNamespace : AttrInterface<"AttrInterfaceInNamespace"> { + let cppNamespace = "test"; +} + +def OpUsingAllOfThose : Op { + let arguments = (ins TopLevelAttrInterface:$attr1, AttrInterfaceInNamespace:$attr2); + let results = (outs TopLevelTypeInterface:$res1, TypeInterfaceInNamespace:$res2); +} + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}( +// CHECK: if (!((type.isa()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be TopLevelTypeInterface instance, but got " << type; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}( +// CHECK: if (!((type.isa()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be TypeInterfaceInNamespace instance, but got " << type; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}( +// CHECK: if (attr && !((attr.isa()))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: TopLevelAttrInterface instance"; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}( +// CHECK: if (attr && !((attr.isa()))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: AttrInterfaceInNamespace instance"; + +// CHECK: TopLevelAttrInterface OpUsingAllOfThose::attr1() +// CHECK: test::AttrInterfaceInNamespace OpUsingAllOfThose::attr2() diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 212fe0e1204e..54190aac15fa 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2304,7 +2304,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc, // DeclareOpInterfaceMethods // and the like. // TODO: Add hasCppInterface check. - if (auto name = def.getValueAsOptionalString("cppClassName")) { + if (auto name = def.getValueAsOptionalString("cppInterfaceName")) { if (*name == "InferTypeOpInterface" && def.getValueAsString("cppNamespace") == "::mlir") canInferResultTypes = true;