Enable autogenerating OpInterface method declarations

Add DeclareOpInterfaceFunctions to enable specifying whether OpInterfaceMethods
for an OpInterface should be generated automatically. This avoids needing to
declare the extra methods, while also allowing adding function declaration by way of trait/inheritance.

Most of this change is mechanical/extracting classes to be reusable.

PiperOrigin-RevId: 272042739
This commit is contained in:
Jacques Pienaar 2019-09-30 12:42:31 -07:00 committed by A. Unique TensorFlower
parent 8e67039e31
commit 0b81eb928b
9 changed files with 315 additions and 122 deletions

View File

@ -330,6 +330,10 @@ An `InterfaceMethod` is comprised of the following components:
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
used to refer to an instance of the derived operation.
ODS also allows generating the declarations for the `InterfaceMethod` of the op
if one specifies the interface with `DeclareOpInterfaceMethods` (see example
below).
Examples:
```tablegen
@ -369,6 +373,13 @@ def MyInterface : OpInterface<"MyInterface"> {
}]>,
];
}
// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This
// would result in autogenerating declarations for members `foo`, `bar` and
// `fooStatic`. Methods without bodies are not declared inside the op
// declaration but instead handled by the op interface trait directly.
def OpWithInferTypeInterfaceOp : Op<...
[DeclareOpInterfaceMethods<MyInterface>]> { ... }
```
### Custom builder methods

View File

@ -1279,16 +1279,16 @@ class InterfaceMethod<string desc, string retTy, string methodName,
// A human-readable description of what this method does.
string description = desc;
/// The name of the interface method.
// The name of the interface method.
string name = methodName;
/// The c++ type-name of the return type.
// The c++ type-name of the return type.
string returnType = retTy;
/// A dag of string that correspond to the arguments of the method.
// A dag of string that correspond to the arguments of the method.
dag arguments = args;
/// An optional body to the method.
// An optional body to the method.
code body = methodBody;
}
@ -1305,10 +1305,20 @@ class OpInterface<string name> : OpInterfaceTrait<name> {
// The name given to the c++ interface class.
string cppClassName = name;
/// The list of methods defined by this interface.
// The list of methods defined by this interface.
list<InterfaceMethod> methods = [];
}
// Whether to declare the op interface methods in the op's header. This class
// simply wraps an OpInterface but is used to indicate that the method
// declarations should be generated.
class DeclareOpInterfaceMethods<OpInterface interface> :
OpInterface<interface.cppClassName> {
let description = interface.description;
let cppClassName = interface.cppClassName;
let methods = interface.methods;
}
//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,105 @@
//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_OPINTERFACES_H_
#define MLIR_TABLEGEN_OPINTERFACES_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Init;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing OpInterfaceMethod defined
// in TableGen.
class OpInterfaceMethod {
public:
// This struct represents a single method argument.
struct Argument {
StringRef type;
StringRef name;
};
explicit OpInterfaceMethod(const llvm::Record *def);
// Return the return type of this method.
StringRef getReturnType() const;
// Return the name of this method.
StringRef getName() const;
// Return if this method is static.
bool isStatic() const;
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
// Arguments.
ArrayRef<Argument> getArguments() const;
bool arg_empty() const;
private:
// The TableGen definition of this method.
const llvm::Record *def;
// The arguments of this method.
SmallVector<Argument, 2> arguments;
};
//===----------------------------------------------------------------------===//
// OpInterface
//===----------------------------------------------------------------------===//
// Wrapper class with helper methods for accessing OpInterfaces defined in
// TableGen.
class OpInterface {
public:
explicit OpInterface(const llvm::Record *def);
// Return the name of this interface.
StringRef getName() const;
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> getMethods() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
private:
// The TableGen definition of this interface.
const llvm::Record *def;
// The methods of this interface.
SmallVector<OpInterfaceMethod, 8> methods;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_OPINTERFACES_H_

View File

@ -33,6 +33,8 @@ class Record;
namespace mlir {
namespace tblgen {
class OpInterface;
// Wrapper class with helper methods for accessing OpTrait constraints defined
// in TableGen.
class OpTrait {
@ -44,7 +46,9 @@ public:
// OpTrait corresponding to predicate on operation.
Pred,
// OpTrait controlling op definition generator internals.
Internal
Internal,
// OpTrait corresponding to OpInterface.
Interface
};
explicit OpTrait(Kind kind, const llvm::Record *def);
@ -92,6 +96,23 @@ public:
}
};
// OpTrait corresponding to an OpInterface on the operation.
class InterfaceOpTrait : public OpTrait {
public:
// Returns member function defitions corresponding to the trait,
OpInterface getOpInterface() const;
// Returns the trait corresponding to a C++ trait class.
StringRef getTrait() const;
static bool classof(const OpTrait *t) {
return t->getKind() == Kind::Interface;
}
// Whether the declaration of methods for this trait should be emitted.
bool shouldDeclareMethods() const;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -0,0 +1,90 @@
//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpInterfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
}
}
StringRef OpInterfaceMethod::getReturnType() const {
return def->getValueAsString("returnType");
}
// Return the name of this method.
StringRef OpInterfaceMethod::getName() const {
return def->getValueAsString("name");
}
// Return if this method is static.
bool OpInterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
}
// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
auto value = def->getValueAsString("body");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
ArrayRef<OpInterfaceMethod::Argument> OpInterfaceMethod::getArguments() const {
return arguments;
}
bool OpInterfaceMethod::arg_empty() const { return arguments.empty(); }
OpInterface::OpInterface(const llvm::Record *def) : def(def) {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
}
// Return the name of this interface.
StringRef OpInterface::getName() const {
return def->getValueAsString("cppClassName");
}
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> OpInterface::getMethods() const { return methods; }
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterface::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}

View File

@ -20,6 +20,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpTrait.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@ -27,33 +29,47 @@
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) {
OpTrait OpTrait::create(const llvm::Init *init) {
auto def = cast<llvm::DefInit>(init)->getDef();
if (def->isSubClassOf("PredOpTrait"))
return OpTrait(Kind::Pred, def);
if (def->isSubClassOf("GenInternalOpTrait"))
return OpTrait(Kind::Internal, def);
if (def->isSubClassOf("OpInterface"))
return OpTrait(Kind::Interface, def);
assert(def->isSubClassOf("NativeOpTrait"));
return OpTrait(Kind::Native, def);
}
mlir::tblgen::OpTrait::OpTrait(Kind kind, const llvm::Record *def)
: def(def), kind(kind) {}
OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const {
llvm::StringRef NativeOpTrait::getTrait() const {
return def->getValueAsString("trait");
}
llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const {
llvm::StringRef InternalOpTrait::getTrait() const {
return def->getValueAsString("trait");
}
std::string mlir::tblgen::PredOpTrait::getPredTemplate() const {
std::string PredOpTrait::getPredTemplate() const {
auto pred = tblgen::Pred(def->getValueInit("predicate"));
return pred.getCondition();
}
llvm::StringRef mlir::tblgen::PredOpTrait::getDescription() const {
llvm::StringRef PredOpTrait::getDescription() const {
return def->getValueAsString("description");
}
OpInterface InterfaceOpTrait::getOpInterface() const {
return OpInterface(def);
}
llvm::StringRef InterfaceOpTrait::getTrait() const {
return def->getValueAsString("trait");
}
bool InterfaceOpTrait::shouldDeclareMethods() const {
return def->isSubClassOf("DeclareOpInterfaceMethods");
}

View File

@ -330,15 +330,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
[InferTypeOpInterface]> {
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor:$res);
// TODO(jpienaar): Remove the need to specify these here.
let extraClassDeclaration = [{
SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
ArrayRef<Region> regions);
}];
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;

View File

@ -23,6 +23,7 @@
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
@ -541,6 +542,9 @@ private:
// Generates the traits used by the object.
void genTraits();
// Generate the OpInterface methods.
void genOpInterfaceMethods();
private:
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@ -577,6 +581,7 @@ OpEmitter::OpEmitter(const Operator &op)
genVerifier();
genCanonicalizerDecls();
genFolderDecls();
genOpInterfaceMethods();
}
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@ -1071,6 +1076,30 @@ void OpEmitter::genFolderDecls() {
}
}
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait);
if (!opTrait || !opTrait->shouldDeclareMethods())
continue;
auto interface = opTrait->getOpInterface();
for (auto method : interface.getMethods()) {
// Don't declare if the method has a body.
if (method.getBody())
continue;
std::string args;
llvm::raw_string_ostream os(args);
mlir::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
method.isStatic() ? OpMethod::MP_Static
: OpMethod::MP_None,
/*declOnly=*/true);
}
}
}
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser"))
return;
@ -1286,6 +1315,8 @@ void OpEmitter::genTraits() {
for (const auto &trait : op.getTraits()) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
}
// Add variadic size trait and normal op traits.

View File

@ -22,6 +22,7 @@
#include "DocGenUtilities.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@ -32,98 +33,8 @@
using namespace llvm;
using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// OpInterfaceMethod
//===----------------------------------------------------------------------===//
// This struct represents a single method argument.
struct MethodArgument {
StringRef type, name;
};
// Wrapper class around a single interface method.
class OpInterfaceMethod {
public:
explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
}
}
// Return the return type of this method.
StringRef getReturnType() const {
return def->getValueAsString("returnType");
}
// Return the name of this method.
StringRef getName() const { return def->getValueAsString("name"); }
// Return if this method is static.
bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const {
auto value = def->getValueAsString("body");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Arguments.
ArrayRef<MethodArgument> getArguments() const { return arguments; }
bool arg_empty() const { return arguments.empty(); }
protected:
// The TableGen definition of this method.
const llvm::Record *def;
// The arguments of this method.
SmallVector<MethodArgument, 2> arguments;
};
//===----------------------------------------------------------------------===//
// OpInterface
//===----------------------------------------------------------------------===//
// Wrapper class with helper methods for accessing OpInterfaces defined in
// TableGen.
class OpInterface {
public:
explicit OpInterface(const llvm::Record *def) : def(def) {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
}
// Return the name of this interface.
StringRef getName() const { return def->getValueAsString("cppClassName"); }
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
protected:
// The TableGen definition of this interface.
const llvm::Record *def;
// The methods of this interface.
SmallVector<OpInterfaceMethod, 8> methods;
};
} // end anonymous namespace
using mlir::tblgen::OpInterface;
using mlir::tblgen::OpInterfaceMethod;
// Emit the method name and argument list for the given method. If
// 'addOperationArg' is true, then an Operation* argument is added to the
@ -133,9 +44,10 @@ static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
os << method.getName() << '(';
if (addOperationArg)
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
os << arg.type << " " << arg.name;
});
interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ')';
}
@ -155,8 +67,9 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
os << " {\n return getImpl()->" << method.getName() << '(';
if (!method.isStatic())
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os,
[&](const MethodArgument &arg) { os << arg.name; });
interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
}
@ -218,8 +131,9 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
// Add the arguments to the call.
os << method.getName() << '(';
interleaveComma(method.getArguments(), os,
[&](const MethodArgument &arg) { os << arg.name; });
interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
os << " };\n";
@ -294,9 +208,10 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
emitCPPType(arg.type, os) << arg.name;
});
interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os) << arg.name;
});
os << ");\n```\n";
// Emit the description.