From 9f186bb125d697786066f1fdd1d0c0e0479a3a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Wed, 6 Jul 2022 16:42:21 +0200 Subject: [PATCH] [mlir][ods] Make Type- and AttrInterfaces also `Type`s and `Attr`s By making TypeInterfaces and AttrInterfaces, Types and Attrs respectively it'd then be possible to use them anywhere where a Type or Attr may go. That is within the arguments and results of an Op definition, in a RewritePattern etc. Prior to this change users had to separately define a Type or Attr, with a predicate to check whether a type or attribute implements a given interface. Such code will be redundant now. Removing such occurrences in upstream dialects will be part of a separate patch. As part of implementing this patch, slight refactoring had to be done. In particular, Interfaces cppClassName field was renamed to cppInterfaceName as it "clashed" with TypeConstraints cppClassName. In particular Interfaces cppClassName expected just the class name, without any namespaces, while TypeConstraints cppClassName expected a fully qualified class name. Differential Revision: https://reviews.llvm.org/D129209 --- mlir/docs/PDLL.md | 2 +- mlir/include/mlir/IR/OpBase.td | 30 ++++++--- mlir/lib/TableGen/Interfaces.cpp | 2 +- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 63 +++++++++---------- mlir/test/mlir-pdll/Parser/include_td.pdll | 14 ++--- .../mlir-tblgen/interfaces-as-constraints.td | 47 ++++++++++++++ mlir/tools/mlir-tblgen/OpFormatGen.cpp | 2 +- 7 files changed, 107 insertions(+), 53 deletions(-) create mode 100644 mlir/test/mlir-tblgen/interfaces-as-constraints.td diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md index 2aadbb6035d0..340940f38547 100644 --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1225,7 +1225,7 @@ was imported: - Imported `Type` constraints utilize the `cppClassName` field for native type translation. * `AttrInterface`/`OpInterface`/`TypeInterface` constraints - - Imported interfaces utilize the `cppClassName` field for native type translation. + - Imported interfaces utilize the `cppInterfaceName` field for native type translation. #### Defining Constraints Inline diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 89c7122b5ba7..16174454b7a1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1949,7 +1949,7 @@ class Interface { string description = ""; // The name given to the c++ interface class. - string cppClassName = name; + string cppInterfaceName = name; // The C++ namespace that this interface should be placed into. // @@ -1970,13 +1970,25 @@ class Interface { } // AttrInterface represents an interface registered to an attribute. -class AttrInterface : Interface, InterfaceTrait; +class AttrInterface : Interface, InterfaceTrait, + Attr()">, + name # " instance"> +{ + let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name; + let returnType = storageType; + let convertFromStorage = "$_self"; +} // OpInterface represents an interface registered to an operation. class OpInterface : Interface, OpInterfaceTrait; // TypeInterface represents an interface registered to a type. -class TypeInterface : Interface, InterfaceTrait; +class TypeInterface : Interface, InterfaceTrait, + Type()">, + name # " instance", + !if(!empty(cppNamespace),"", cppNamespace # "::") # name>; // Whether to declare the interface methods in the user entity's header. This // class simply wraps an Interface but is used to indicate that the method @@ -1992,27 +2004,27 @@ class DeclareInterfaceMethods overridenMethods = []> { class DeclareAttrInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - AttrInterface { + AttrInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareOpInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - OpInterface { + OpInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareTypeInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - TypeInterface { + TypeInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index 4d72ceeb45fc..1ee0b140756f 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -81,7 +81,7 @@ Interface::Interface(const llvm::Record *def) : def(def) { // Return the name of this interface. StringRef Interface::getName() const { - return def->getValueAsString("cppClassName"); + return def->getValueAsString("cppInterfaceName"); } // Return the C++ namespace of this interface. diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 4b7fd85227aa..55b1e3947f3b 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -873,38 +873,43 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, addTypeConstraint(result)); } } + + auto shouldBeSkipped = [this](llvm::Record *def) { + return def->isAnonymous() || curDeclScope->lookup(def->getName()) || + def->isSubClassOf("DeclareInterfaceMethods"); + }; + /// Attr constraints. for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { - if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { - tblgen::Attribute constraint(def); - decls.push_back( - createODSNativePDLLConstraintDecl( - constraint, convertLocToRange(def->getLoc().front()), attrTy, - constraint.getStorageType())); - } + if (shouldBeSkipped(def)) + continue; + + tblgen::Attribute constraint(def); + decls.push_back(createODSNativePDLLConstraintDecl( + constraint, convertLocToRange(def->getLoc().front()), attrTy, + constraint.getStorageType())); } /// Type constraints. for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { - if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { - tblgen::TypeConstraint constraint(def); - decls.push_back( - createODSNativePDLLConstraintDecl( - constraint, convertLocToRange(def->getLoc().front()), typeTy, - constraint.getCPPClassName())); - } - } - /// Interfaces. - ast::Type opTy = ast::OperationType::get(ctx); - for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { - StringRef name = def->getName(); - if (def->isAnonymous() || curDeclScope->lookup(name) || - def->isSubClassOf("DeclareInterfaceMethods")) + if (shouldBeSkipped(def)) continue; + + tblgen::TypeConstraint constraint(def); + decls.push_back(createODSNativePDLLConstraintDecl( + constraint, convertLocToRange(def->getLoc().front()), typeTy, + constraint.getCPPClassName())); + } + /// OpInterfaces. + ast::Type opTy = ast::OperationType::get(ctx); + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) { + if (shouldBeSkipped(def)) + continue; + SMRange loc = convertLocToRange(def->getLoc().front()); std::string cppClassName = llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"), - def->getValueAsString("cppClassName")) + def->getValueAsString("cppInterfaceName")) .str(); std::string codeBlock = llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));", @@ -913,18 +918,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, std::string desc = processAndFormatDoc(def->getValueAsString("description")); - if (def->isSubClassOf("OpInterface")) { - decls.push_back(createODSNativePDLLConstraintDecl( - name, codeBlock, loc, opTy, cppClassName, desc)); - } else if (def->isSubClassOf("AttrInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, attrTy, cppClassName, desc)); - } else if (def->isSubClassOf("TypeInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, typeTy, cppClassName, desc)); - } + decls.push_back(createODSNativePDLLConstraintDecl( + def->getName(), codeBlock, loc, opTy, cppClassName, desc)); } } diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll index f90f7ab8a412..5526aa852482 100644 --- a/mlir/test/mlir-pdll/Parser/include_td.pdll +++ b/mlir/test/mlir-pdll/Parser/include_td.pdll @@ -32,21 +32,21 @@ // CHECK-NEXT: CppClass: ::mlir::IntegerType // CHECK-NEXT: } -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-AttrConstraintDecl +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-TypeConstraintDecl {{.*}} + // CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-OpConstraintDecl // CHECK: `-OpNameDecl - -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> -// CHECK: `Inputs` -// CHECK: `-VariableDecl {{.*}} Name Type -// CHECK: `Constraints` -// CHECK: `-TypeConstraintDecl {{.*}} diff --git a/mlir/test/mlir-tblgen/interfaces-as-constraints.td b/mlir/test/mlir-tblgen/interfaces-as-constraints.td new file mode 100644 index 000000000000..5963dd8bb8ac --- /dev/null +++ b/mlir/test/mlir-tblgen/interfaces-as-constraints.td @@ -0,0 +1,47 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def TopLevelTypeInterface : TypeInterface<"TopLevelTypeInterface">; + +def TypeInterfaceInNamespace : TypeInterface<"TypeInterfaceInNamespace"> { + let cppNamespace = "test"; +} + +def TopLevelAttrInterface : AttrInterface<"TopLevelAttrInterface">; + +def AttrInterfaceInNamespace : AttrInterface<"AttrInterfaceInNamespace"> { + let cppNamespace = "test"; +} + +def OpUsingAllOfThose : Op { + let arguments = (ins TopLevelAttrInterface:$attr1, AttrInterfaceInNamespace:$attr2); + let results = (outs TopLevelTypeInterface:$res1, TypeInterfaceInNamespace:$res2); +} + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}( +// CHECK: if (!((type.isa()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be TopLevelTypeInterface instance, but got " << type; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}( +// CHECK: if (!((type.isa()))) { +// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex +// CHECK-NEXT: << " must be TypeInterfaceInNamespace instance, but got " << type; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}( +// CHECK: if (attr && !((attr.isa()))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: TopLevelAttrInterface instance"; + +// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}( +// CHECK: if (attr && !((attr.isa()))) { +// CHECK-NEXT: return op->emitOpError("attribute '") << attrName +// CHECK-NEXT: << "' failed to satisfy constraint: AttrInterfaceInNamespace instance"; + +// CHECK: TopLevelAttrInterface OpUsingAllOfThose::attr1() +// CHECK: test::AttrInterfaceInNamespace OpUsingAllOfThose::attr2() diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 212fe0e1204e..54190aac15fa 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2304,7 +2304,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc, // DeclareOpInterfaceMethods // and the like. // TODO: Add hasCppInterface check. - if (auto name = def.getValueAsOptionalString("cppClassName")) { + if (auto name = def.getValueAsOptionalString("cppInterfaceName")) { if (*name == "InferTypeOpInterface" && def.getValueAsString("cppNamespace") == "::mlir") canInferResultTypes = true;