From 8165f181d9a1ff503919d3625f6a48955a226b3c Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 19 Aug 2019 12:43:46 -0700 Subject: [PATCH] Add support for Operation interfaces. Operation interfaces, as the name suggests, are those registered at the Operation level. These interfaces provide an opaque view into derived operations, by providing a virtual interface that must be implemented. As an example, the Linalg dialect implements an interface LinalgOp that provides general queries about some of the dialects library operations. These queries may provide things like: the number of parallel loops, the number of inputs and outputs, etc. Operation interfaces are defined by overriding the CRTP base class OpInterface. This class takes as a template parameter, a `Traits` class that defines a Concept and a Model class. These classes provide an implementation of concept-based polymorphism, where the Concept defines a set of virtual methods that are overridden by the Model that is templated on the concrete operation type. It is important to note that these classes should be pure in that they contain no non-static data members. PiperOrigin-RevId: 264218741 --- mlir/g3doc/Interfaces.md | 70 +++++++ mlir/include/mlir/IR/OpBase.td | 13 ++ mlir/include/mlir/IR/OpDefinition.h | 113 ++++++++++- mlir/include/mlir/IR/OperationSupport.h | 19 +- .../mlir/Linalg/IR/LinalgLibraryOps.td | 7 +- mlir/include/mlir/Linalg/IR/LinalgOps.h | 181 +++++++----------- 6 files changed, 287 insertions(+), 116 deletions(-) diff --git a/mlir/g3doc/Interfaces.md b/mlir/g3doc/Interfaces.md index b1f823aac7ac..666654ae26c7 100644 --- a/mlir/g3doc/Interfaces.md +++ b/mlir/g3doc/Interfaces.md @@ -101,3 +101,73 @@ InlinerInterface interface(ctx); if(!interface.isLegalToInline(...)) ... ``` + +### Operation Interfaces + +Operation interfaces, as the name suggests, are those registered at the +Operation level. These interfaces provide an opaque view into derived +operations, by providing a virtual interface that must be implemented. As an +example, the `Linalg` dialect may implement an interface that provides general +queries about some of the dialects library operations. These queries may provide +things like: the number of parallel loops, the number of inputs and outputs, +etc. + +Operation interfaces are defined by overriding the CRTP base class +`OpInterface`. This class takes as a template parameter, a `Traits` class that +defines a `Concept` and a `Model` class. These classes provide an implementation +of concept-based polymorphism, where the Concept defines a set of virtual +methods that are overridden by the Model that is templated on the concrete +operation type. It is important to note that these classes should be pure in +that they contain no non-static data members. Operations that wish to override +this interface should add the provided trait `OpInterface<..>::Trait` upon +registration. + +```c++ +struct ExampleOpInterfaceTraits { +/// Define a base concept class that defines the virtual interface that needs +/// to be overridden. +struct Concept { + virtual ~Concept(); + virtual unsigned getNumInputs(Operation *op) = 0; +}; + +/// Define a model class that specializes a concept on a given operation type. +template +struct Model { + /// Override the method to dispatch on the concrete operation. + unsigned getNumInputs(Operation *op) final { + return llvm::cast(op).getNumInputs(); + } +}; +}; + +class ExampleOpInterface : public OpInterface { +public: + /// The interface dispatches to 'getImpl()', an instance of the concept. + unsigned getNumInputs() { + return getImpl()->getNumInputs(getOperation()); + } +}; + +``` + +Once the interface has been defined, it is registered to an operation by adding +the provided trait `ExampleOpInterface::Trait`. Using this interface is just +like using any other derived operation type, i.e. casting: + +```c++ +/// When defining the operation, the interface is registered via the nested +/// 'Trait' class provided by the 'OpInterface<>' base class. +class MyOp : public Op { +public: + /// The definition of the interface method on the derived operation. + unsigned getNumInputs() { return ...; } +}; + +/// Later, we can query if a specific operation(like 'MyOp') overrides the given +/// interface. +Operation *op = ...; +if (ExampleOpInterface example = dyn_cast(op)) + llvm::errs() << "num inputs = " << example.getNumInputs() << "\n"; +``` diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index eb49c237bb21..f1349799dc8c 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1069,6 +1069,19 @@ def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; // to have the same array size. def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; +//===----------------------------------------------------------------------===// +// OpInterface definitions +//===----------------------------------------------------------------------===// + +// NativeOpInterface corresponds to a specific 'OpInterface' class defined in +// C++. The purpose to wrap around C++ symbol string with this class is to make +// interfaces specified for ops in TableGen less alien and more integrated. +class NativeOpInterface : NativeOpTrait<""> { + // TODO(riverriddle) Remove when operation interfaces have their own trait + // subclass. + let trait = prop # "::Trait"; +} + //===----------------------------------------------------------------------===// // Op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index ed68936d5062..fd3526285269 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -996,10 +996,121 @@ private: traitID); } - /// Allow access to 'hasTrait'. + /// Returns an opaque pointer to a concept instance of the interface with the + /// given ID if one was registered to this operation. + static void *getRawInterface(ClassID *id) { + return InterfaceLookup::template lookup...>(id); + } + + struct InterfaceLookup { + /// Trait to check if T provides a static 'getInterfaceID' method. + template + using has_get_interface_id = decltype(T::getInterfaceID()); + + /// If 'T' is the same interface as 'interfaceID' return the concept + /// instance. + template + static typename std::enable_if::value, + void *>::type + lookup(ClassID *interfaceID) { + return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr; + } + + /// 'T' is known to not be an interface, return nullptr. + template + static typename std::enable_if::value, + void *>::type + lookup(ClassID *) { + return nullptr; + } + + template + static void *lookup(ClassID *interfaceID) { + auto *concept = lookup(interfaceID); + return concept ? concept : lookup(interfaceID); + } + }; + + /// Allow access to 'hasTrait' and 'getRawInterface'. friend AbstractOperation; }; +/// This class represents the base of an operation interface. Operation +/// interfaces provide access to derived *Op properties through an opaquely +/// Operation instance. Derived interfaces must also provide a 'Traits' class +/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an +/// abstract virtual interface, where as the 'Model' class implements this +/// interface for a specific derived *Op type. Both of these classes *must* not +/// contain non-static data. A simple example is shown below: +/// +/// struct ExampleOpInterfaceTraits { +/// struct Concept { +/// virtual unsigned getNumInputs(Operation *op) = 0; +/// }; +/// template class Model { +/// unsigned getNumInputs(Operation *op) final { +/// return llvm::cast(op).getNumInputs(); +/// } +/// }; +/// }; +/// +template +class OpInterface : public Op { +public: + using Concept = typename Traits::Concept; + template using Model = typename Traits::template Model; + + OpInterface(Operation *op = nullptr) + : Op(op), impl(op ? getInterfaceFor(op) : nullptr) { + assert((!op || impl) && + "instantiating an interface with an unregistered operation"); + } + + /// Support 'classof' by checking if the given operation defines the concrete + /// interface. + static bool classof(Operation *op) { return getInterfaceFor(op); } + + /// Define an accessor for the ID of this interface. + static ClassID *getInterfaceID() { return ClassID::getID(); } + + /// This is a special trait that registers a given interface with an + /// operation. + template + struct Trait : public OpTrait::TraitBase { + /// Define an accessor for the ID of this interface. + static ClassID *getInterfaceID() { return ClassID::getID(); } + + /// Provide an accessor to a static instance of the interface model for the + /// concrete operation type. + /// The implementation is inspired from Sean Parent's concept-based + /// polymorphism. A key difference is that the set of classes erased is + /// statically known, which alleviates the need for using dynamic memory + /// allocation. + /// We use a zero-sized templated class `Model` to emit the + /// virtual table and generate a singleton object for each instantiation of + /// this class. + static Concept &instance() { + static Model singleton; + return singleton; + } + }; + +protected: + /// Get the raw concept in the correct derived concept type. + Concept *getImpl() { return impl; } + +private: + /// Returns the impl interface instance for the given operation. + static Concept *getInterfaceFor(Operation *op) { + // Access the raw interface from the abstract operation. + auto *abstractOp = op->getAbstractOperation(); + return abstractOp ? abstractOp->getInterface() : nullptr; + } + + /// A pointer to the impl concept object. + Concept *impl; +}; + // These functions are out-of-line implementations of the methods in BinaryOp, // which avoids them being template instantiated/duplicated. namespace impl { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 204da29b39ad..4871c856e98e 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -140,6 +140,14 @@ public: return opProperties & static_cast(property); } + /// Returns an instance of the concept object for the given interface if it + /// was registered to this operation, null otherwise. This should not be used + /// directly. + template typename T::Concept *getInterface() const { + return reinterpret_cast( + getRawInterface(T::getInterfaceID())); + } + /// Returns if the operation has a particular trait. template