Add support for providing a default implementation for an interface method.

This enables providing a default implementation of an interface method. This method is defined on the Trait that is attached to the operation, and thus has all of the same constraints and properties as any other interface method. This allows for interface authors to provide a conservative default implementation for certain methods, without requiring that all users explicitly define it. The default implementation can be specified via the argument directly after the interface method body:

  StaticInterfaceMethod<
    /*desc=*/"Returns whether two array of types are compatible result types for an op.",
    /*retTy=*/"bool",
    /*methodName=*/"isCompatibleReturnTypes",
    /*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
    /*methodBody=*/[{
      return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
    }],
    /*defaultImplementation=*/[{
      /// Returns whether two arrays are equal as strongest check for
      /// compatibility by default.
      return lhs == rhs;
    }]

PiperOrigin-RevId: 286226054
This commit is contained in:
River Riddle 2019-12-18 11:02:35 -08:00 committed by A. Unique TensorFlower
parent d7e2cc9bd1
commit 29807ff5e4
7 changed files with 66 additions and 10 deletions

View File

@ -332,6 +332,13 @@ An `InterfaceMethod` is comprised of the following components:
to the type of the derived operation currently being operated on.
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
used to refer to an instance of the derived operation.
* DefaultImplementation (Optional)
- An optional explicit default implementation of the interface method.
- This method is placed within the `Trait` class that is attached to the
operation. As such, this method has the same characteristics as any
other [`Trait`](Traits.md) method.
- `ConcreteOp` is an implicitly defined typename that can be used to refer
to the type of the derived operation currently being operated on.
ODS also allows generating the declarations for the `InterfaceMethod` of the op
if one specifies the interface with `DeclareOpInterfaceMethods` (see example
@ -374,6 +381,14 @@ def MyInterface : OpInterface<"MyInterface"> {
"unsigned", "getNumInputsAndOutputs", (ins), [{
return op.getNumInputs() + op.getNumOutputs();
}]>,
// Provide only a default definition of the method.
// Note: `ConcreteOp` corresponds to the derived operation typename.
InterfaceMethod<"/*insert doc here*/",
"unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{
ConcreteOp op = cast<ConcreteOp>(getOperation());
return op.getNumInputs() + op.getNumOutputs();
}]>,
];
}

View File

@ -59,15 +59,16 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
/*retTy=*/"bool",
/*methodName=*/"isCompatibleReturnTypes",
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
[{
/*methodBody=*/[{
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
}],
/*defaultImplementation=*/[{
/// Returns whether two arrays are equal as strongest check for
/// compatibility by default.
return lhs == rhs;
}]
>,
];
}
// Default implementations for some of the interface methods above:
// - compatibleReturnTypes returns whether strictly true.
def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
#endif // MLIR_INFERTYPEOPINTERFACE

View File

@ -1425,7 +1425,8 @@ class OpInterfaceTrait<string name> : NativeOpTrait<""> {
// Note: non-static interface methods have an implicit 'op' parameter
// corresponding to an instance of the derived operation.
class InterfaceMethod<string desc, string retTy, string methodName,
dag args = (ins), code methodBody = [{}]> {
dag args = (ins), code methodBody = [{}],
code defaultImplementation = [{}]> {
// A human-readable description of what this method does.
string description = desc;
@ -1440,12 +1441,17 @@ class InterfaceMethod<string desc, string retTy, string methodName,
// An optional body to the method.
code body = methodBody;
// An optional default implementation of the method.
code defaultBody = defaultImplementation;
}
// This class represents a single static interface method.
class StaticInterfaceMethod<string desc, string retTy, string methodName,
dag args = (ins), code methodBody = [{}]>
: InterfaceMethod<desc, retTy, methodName, args, methodBody>;
dag args = (ins), code methodBody = [{}],
code defaultImplementation = [{}]>
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
defaultImplementation>;
// OpInterface represents an interface regarding an op.
class OpInterface<string name> : OpInterfaceTrait<name> {

View File

@ -58,6 +58,9 @@ public:
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const;
// Return the default implementation for this method if it has one.
llvm::Optional<StringRef> getDefaultImplementation() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;

View File

@ -57,6 +57,12 @@ llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the default implementation for this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDefaultImplementation() const {
auto value = def->getValueAsString("defaultBody");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
auto value = def->getValueAsString("description");

View File

@ -403,8 +403,7 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
InferTypeOpInterfaceDefault]> {
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}

View File

@ -151,6 +151,29 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
os << " };\n";
}
static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("struct Trait : public OpInterface<{0},"
" detail::{1}>::Trait<ConcreteOp> {{\n",
interfaceName, interfaceTraitsName);
// Insert the default implementation for any methods.
for (auto &method : interface.getMethods()) {
auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl)
continue;
os << " " << (method.isStatic() ? "static " : "") << method.getReturnType()
<< " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
os << " {\n" << defaultImpl.getValue() << " }\n";
}
os << " };\n";
}
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
@ -168,6 +191,9 @@ static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
interfaceName, interfaceName, interfaceTraitsName);
// Emit the derived trait for the interface.
emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);
// Insert the method declarations.
for (auto &method : interface.getMethods()) {
os << " " << method.getReturnType() << " ";