diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index ff7192c71c26..d852de7477b9 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -330,6 +330,10 @@ An `InterfaceMethod` is comprised of the following components: - In non-static methods, a variable 'ConcreteOp op' is defined and may be used to refer to an instance of the derived operation. +ODS also allows generating the declarations for the `InterfaceMethod` of the op +if one specifies the interface with `DeclareOpInterfaceMethods` (see example +below). + Examples: ```tablegen @@ -369,6 +373,13 @@ def MyInterface : OpInterface<"MyInterface"> { }]>, ]; } + +// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This +// would result in autogenerating declarations for members `foo`, `bar` and +// `fooStatic`. Methods without bodies are not declared inside the op +// declaration but instead handled by the op interface trait directly. +def OpWithInferTypeInterfaceOp : Op<... + [DeclareOpInterfaceMethods]> { ... } ``` ### Custom builder methods diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 89609ed3fd03..c662576fdb26 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1279,16 +1279,16 @@ class InterfaceMethod : OpInterfaceTrait { // The name given to the c++ interface class. string cppClassName = name; - /// The list of methods defined by this interface. + // The list of methods defined by this interface. list methods = []; } +// Whether to declare the op interface methods in the op's header. This class +// simply wraps an OpInterface but is used to indicate that the method +// declarations should be generated. +class DeclareOpInterfaceMethods : + OpInterface { + let description = interface.description; + let cppClassName = interface.cppClassName; + let methods = interface.methods; +} + //===----------------------------------------------------------------------===// // Op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h new file mode 100644 index 000000000000..46f43c683cc9 --- /dev/null +++ b/mlir/include/mlir/TableGen/OpInterfaces.h @@ -0,0 +1,105 @@ +//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// OpInterfaces wrapper to simplify using TableGen OpInterfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_OPINTERFACES_H_ +#define MLIR_TABLEGEN_OPINTERFACES_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class Init; +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class with helper methods for accessing OpInterfaceMethod defined +// in TableGen. +class OpInterfaceMethod { +public: + // This struct represents a single method argument. + struct Argument { + StringRef type; + StringRef name; + }; + + explicit OpInterfaceMethod(const llvm::Record *def); + + // Return the return type of this method. + StringRef getReturnType() const; + + // Return the name of this method. + StringRef getName() const; + + // Return if this method is static. + bool isStatic() const; + + // Return the body for this method if it has one. + llvm::Optional getBody() const; + + // Return the description of this method if it has one. + llvm::Optional getDescription() const; + + // Arguments. + ArrayRef getArguments() const; + bool arg_empty() const; + +private: + // The TableGen definition of this method. + const llvm::Record *def; + + // The arguments of this method. + SmallVector arguments; +}; + +//===----------------------------------------------------------------------===// +// OpInterface +//===----------------------------------------------------------------------===// + +// Wrapper class with helper methods for accessing OpInterfaces defined in +// TableGen. +class OpInterface { +public: + explicit OpInterface(const llvm::Record *def); + + // Return the name of this interface. + StringRef getName() const; + + // Return the methods of this interface. + ArrayRef getMethods() const; + + // Return the description of this method if it has one. + llvm::Optional getDescription() const; + +private: + // The TableGen definition of this interface. + const llvm::Record *def; + + // The methods of this interface. + SmallVector methods; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_OPINTERFACES_H_ diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h index 8a3463d257e4..cfa1d93951d2 100644 --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -33,6 +33,8 @@ class Record; namespace mlir { namespace tblgen { +class OpInterface; + // Wrapper class with helper methods for accessing OpTrait constraints defined // in TableGen. class OpTrait { @@ -44,7 +46,9 @@ public: // OpTrait corresponding to predicate on operation. Pred, // OpTrait controlling op definition generator internals. - Internal + Internal, + // OpTrait corresponding to OpInterface. + Interface }; explicit OpTrait(Kind kind, const llvm::Record *def); @@ -92,6 +96,23 @@ public: } }; +// OpTrait corresponding to an OpInterface on the operation. +class InterfaceOpTrait : public OpTrait { +public: + // Returns member function defitions corresponding to the trait, + OpInterface getOpInterface() const; + + // Returns the trait corresponding to a C++ trait class. + StringRef getTrait() const; + + static bool classof(const OpTrait *t) { + return t->getKind() == Kind::Interface; + } + + // Whether the declaration of methods for this trait should be emitted. + bool shouldDeclareMethods() const; +}; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp new file mode 100644 index 000000000000..e4e80e06676d --- /dev/null +++ b/mlir/lib/TableGen/OpInterfaces.cpp @@ -0,0 +1,90 @@ +//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// OpInterfaces wrapper to simplify using TableGen OpInterfaces. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/OpInterfaces.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) { + llvm::DagInit *args = def->getValueAsDag("arguments"); + for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { + arguments.push_back( + {llvm::cast(args->getArg(i))->getValue(), + args->getArgNameStr(i)}); + } +} + +StringRef OpInterfaceMethod::getReturnType() const { + return def->getValueAsString("returnType"); +} + +// Return the name of this method. +StringRef OpInterfaceMethod::getName() const { + return def->getValueAsString("name"); +} + +// Return if this method is static. +bool OpInterfaceMethod::isStatic() const { + return def->isSubClassOf("StaticInterfaceMethod"); +} + +// Return the body for this method if it has one. +llvm::Optional OpInterfaceMethod::getBody() const { + auto value = def->getValueAsString("body"); + return value.empty() ? llvm::Optional() : value; +} + +// Return the description of this method if it has one. +llvm::Optional OpInterfaceMethod::getDescription() const { + auto value = def->getValueAsString("description"); + return value.empty() ? llvm::Optional() : value; +} + +ArrayRef OpInterfaceMethod::getArguments() const { + return arguments; +} + +bool OpInterfaceMethod::arg_empty() const { return arguments.empty(); } + +OpInterface::OpInterface(const llvm::Record *def) : def(def) { + auto *listInit = dyn_cast(def->getValueInit("methods")); + for (llvm::Init *init : listInit->getValues()) + methods.emplace_back(cast(init)->getDef()); +} + +// Return the name of this interface. +StringRef OpInterface::getName() const { + return def->getValueAsString("cppClassName"); +} + +// Return the methods of this interface. +ArrayRef OpInterface::getMethods() const { return methods; } + +// Return the description of this method if it has one. +llvm::Optional OpInterface::getDescription() const { + auto value = def->getValueAsString("description"); + return value.empty() ? llvm::Optional() : value; +} diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp index 0a357acb7440..0e436a874974 100644 --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -20,6 +20,8 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/OpTrait.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/Predicate.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -27,33 +29,47 @@ #include "llvm/TableGen/Record.h" using namespace mlir; +using namespace mlir::tblgen; -mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) { +OpTrait OpTrait::create(const llvm::Init *init) { auto def = cast(init)->getDef(); if (def->isSubClassOf("PredOpTrait")) return OpTrait(Kind::Pred, def); if (def->isSubClassOf("GenInternalOpTrait")) return OpTrait(Kind::Internal, def); + if (def->isSubClassOf("OpInterface")) + return OpTrait(Kind::Interface, def); assert(def->isSubClassOf("NativeOpTrait")); return OpTrait(Kind::Native, def); } -mlir::tblgen::OpTrait::OpTrait(Kind kind, const llvm::Record *def) - : def(def), kind(kind) {} +OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {} -llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const { +llvm::StringRef NativeOpTrait::getTrait() const { return def->getValueAsString("trait"); } -llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const { +llvm::StringRef InternalOpTrait::getTrait() const { return def->getValueAsString("trait"); } -std::string mlir::tblgen::PredOpTrait::getPredTemplate() const { +std::string PredOpTrait::getPredTemplate() const { auto pred = tblgen::Pred(def->getValueInit("predicate")); return pred.getCondition(); } -llvm::StringRef mlir::tblgen::PredOpTrait::getDescription() const { +llvm::StringRef PredOpTrait::getDescription() const { return def->getValueAsString("description"); } + +OpInterface InterfaceOpTrait::getOpInterface() const { + return OpInterface(def); +} + +llvm::StringRef InterfaceOpTrait::getTrait() const { + return def->getValueAsString("trait"); +} + +bool InterfaceOpTrait::shouldDeclareMethods() const { + return def->isSubClassOf("DeclareOpInterfaceMethods"); +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 72991ced497b..e419b7ef3b12 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -330,15 +330,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> { } def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", - [InferTypeOpInterface]> { + [DeclareOpInterfaceMethods]> { let arguments = (ins AnyTensor:$x, AnyTensor:$y); let results = (outs AnyTensor:$res); - // TODO(jpienaar): Remove the need to specify these here. - let extraClassDeclaration = [{ - SmallVector inferReturnTypes(llvm::Optional location, - ArrayRef operands, ArrayRef attributes, - ArrayRef regions); - }]; } def IsNotScalar : Constraint>; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index e3d70a1c669c..089a49229532 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -23,6 +23,7 @@ #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" @@ -541,6 +542,9 @@ private: // Generates the traits used by the object. void genTraits(); + // Generate the OpInterface methods. + void genOpInterfaceMethods(); + private: // The TableGen record for this op. // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly, @@ -577,6 +581,7 @@ OpEmitter::OpEmitter(const Operator &op) genVerifier(); genCanonicalizerDecls(); genFolderDecls(); + genOpInterfaceMethods(); } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { @@ -1071,6 +1076,30 @@ void OpEmitter::genFolderDecls() { } } +void OpEmitter::genOpInterfaceMethods() { + for (const auto &trait : op.getTraits()) { + auto opTrait = dyn_cast(&trait); + if (!opTrait || !opTrait->shouldDeclareMethods()) + continue; + auto interface = opTrait->getOpInterface(); + for (auto method : interface.getMethods()) { + // Don't declare if the method has a body. + if (method.getBody()) + continue; + std::string args; + llvm::raw_string_ostream os(args); + mlir::interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + opClass.newMethod(method.getReturnType(), method.getName(), os.str(), + method.isStatic() ? OpMethod::MP_Static + : OpMethod::MP_None, + /*declOnly=*/true); + } + } +} + void OpEmitter::genParser() { if (!hasStringAttribute(def, "parser")) return; @@ -1286,6 +1315,8 @@ void OpEmitter::genTraits() { for (const auto &trait : op.getTraits()) { if (auto opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getTrait()); + else if (auto opTrait = dyn_cast(&trait)) + opClass.addTrait(opTrait->getTrait()); } // Add variadic size trait and normal op traits. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 7961b6f6f8af..4da412c2f438 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -22,6 +22,7 @@ #include "DocGenUtilities.h" #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/OpInterfaces.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -32,98 +33,8 @@ using namespace llvm; using namespace mlir; - -namespace { -//===----------------------------------------------------------------------===// -// OpInterfaceMethod -//===----------------------------------------------------------------------===// - -// This struct represents a single method argument. -struct MethodArgument { - StringRef type, name; -}; - -// Wrapper class around a single interface method. -class OpInterfaceMethod { -public: - explicit OpInterfaceMethod(const llvm::Record *def) : def(def) { - llvm::DagInit *args = def->getValueAsDag("arguments"); - for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { - arguments.push_back( - {llvm::cast(args->getArg(i))->getValue(), - args->getArgNameStr(i)}); - } - } - - // Return the return type of this method. - StringRef getReturnType() const { - return def->getValueAsString("returnType"); - } - - // Return the name of this method. - StringRef getName() const { return def->getValueAsString("name"); } - - // Return if this method is static. - bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); } - - // Return the body for this method if it has one. - llvm::Optional getBody() const { - auto value = def->getValueAsString("body"); - return value.empty() ? llvm::Optional() : value; - } - - // Return the description of this method if it has one. - llvm::Optional getDescription() const { - auto value = def->getValueAsString("description"); - return value.empty() ? llvm::Optional() : value; - } - - // Arguments. - ArrayRef getArguments() const { return arguments; } - bool arg_empty() const { return arguments.empty(); } - -protected: - // The TableGen definition of this method. - const llvm::Record *def; - - // The arguments of this method. - SmallVector arguments; -}; - -//===----------------------------------------------------------------------===// -// OpInterface -//===----------------------------------------------------------------------===// - -// Wrapper class with helper methods for accessing OpInterfaces defined in -// TableGen. -class OpInterface { -public: - explicit OpInterface(const llvm::Record *def) : def(def) { - auto *listInit = dyn_cast(def->getValueInit("methods")); - for (llvm::Init *init : listInit->getValues()) - methods.emplace_back(cast(init)->getDef()); - } - - // Return the name of this interface. - StringRef getName() const { return def->getValueAsString("cppClassName"); } - - // Return the methods of this interface. - ArrayRef getMethods() const { return methods; } - - // Return the description of this method if it has one. - llvm::Optional getDescription() const { - auto value = def->getValueAsString("description"); - return value.empty() ? llvm::Optional() : value; - } - -protected: - // The TableGen definition of this interface. - const llvm::Record *def; - - // The methods of this interface. - SmallVector methods; -}; -} // end anonymous namespace +using mlir::tblgen::OpInterface; +using mlir::tblgen::OpInterfaceMethod; // Emit the method name and argument list for the given method. If // 'addOperationArg' is true, then an Operation* argument is added to the @@ -133,9 +44,10 @@ static void emitMethodNameAndArgs(const OpInterfaceMethod &method, os << method.getName() << '('; if (addOperationArg) os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", "); - interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) { - os << arg.type << " " << arg.name; - }); + interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); os << ')'; } @@ -155,8 +67,9 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) { os << " {\n return getImpl()->" << method.getName() << '('; if (!method.isStatic()) os << "getOperation()" << (method.arg_empty() ? "" : ", "); - interleaveComma(method.getArguments(), os, - [&](const MethodArgument &arg) { os << arg.name; }); + interleaveComma( + method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; } } @@ -218,8 +131,9 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) { // Add the arguments to the call. os << method.getName() << '('; - interleaveComma(method.getArguments(), os, - [&](const MethodArgument &arg) { os << arg.name; }); + interleaveComma( + method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; } os << " };\n"; @@ -294,9 +208,10 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) { if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os) << method.getName() << '('; - interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) { - emitCPPType(arg.type, os) << arg.name; - }); + interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + emitCPPType(arg.type, os) << arg.name; + }); os << ");\n```\n"; // Emit the description.