forked from OSchip/llvm-project
Add support for providing a default implementation for an interface method.
This enables providing a default implementation of an interface method. This method is defined on the Trait that is attached to the operation, and thus has all of the same constraints and properties as any other interface method. This allows for interface authors to provide a conservative default implementation for certain methods, without requiring that all users explicitly define it. The default implementation can be specified via the argument directly after the interface method body: StaticInterfaceMethod< /*desc=*/"Returns whether two array of types are compatible result types for an op.", /*retTy=*/"bool", /*methodName=*/"isCompatibleReturnTypes", /*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs), /*methodBody=*/[{ return ConcreteOp::isCompatibleReturnTypes(lhs, rhs); }], /*defaultImplementation=*/[{ /// Returns whether two arrays are equal as strongest check for /// compatibility by default. return lhs == rhs; }] PiperOrigin-RevId: 286226054
This commit is contained in:
parent
d7e2cc9bd1
commit
29807ff5e4
|
@ -332,6 +332,13 @@ An `InterfaceMethod` is comprised of the following components:
|
|||
to the type of the derived operation currently being operated on.
|
||||
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
|
||||
used to refer to an instance of the derived operation.
|
||||
* DefaultImplementation (Optional)
|
||||
- An optional explicit default implementation of the interface method.
|
||||
- This method is placed within the `Trait` class that is attached to the
|
||||
operation. As such, this method has the same characteristics as any
|
||||
other [`Trait`](Traits.md) method.
|
||||
- `ConcreteOp` is an implicitly defined typename that can be used to refer
|
||||
to the type of the derived operation currently being operated on.
|
||||
|
||||
ODS also allows generating the declarations for the `InterfaceMethod` of the op
|
||||
if one specifies the interface with `DeclareOpInterfaceMethods` (see example
|
||||
|
@ -374,6 +381,14 @@ def MyInterface : OpInterface<"MyInterface"> {
|
|||
"unsigned", "getNumInputsAndOutputs", (ins), [{
|
||||
return op.getNumInputs() + op.getNumOutputs();
|
||||
}]>,
|
||||
|
||||
// Provide only a default definition of the method.
|
||||
// Note: `ConcreteOp` corresponds to the derived operation typename.
|
||||
InterfaceMethod<"/*insert doc here*/",
|
||||
"unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{
|
||||
ConcreteOp op = cast<ConcreteOp>(getOperation());
|
||||
return op.getNumInputs() + op.getNumOutputs();
|
||||
}]>,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -59,15 +59,16 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
|||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isCompatibleReturnTypes",
|
||||
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
|
||||
[{
|
||||
/*methodBody=*/[{
|
||||
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
|
||||
}],
|
||||
/*defaultImplementation=*/[{
|
||||
/// Returns whether two arrays are equal as strongest check for
|
||||
/// compatibility by default.
|
||||
return lhs == rhs;
|
||||
}]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
// Default implementations for some of the interface methods above:
|
||||
// - compatibleReturnTypes returns whether strictly true.
|
||||
def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
|
||||
|
||||
#endif // MLIR_INFERTYPEOPINTERFACE
|
||||
|
|
|
@ -1425,7 +1425,8 @@ class OpInterfaceTrait<string name> : NativeOpTrait<""> {
|
|||
// Note: non-static interface methods have an implicit 'op' parameter
|
||||
// corresponding to an instance of the derived operation.
|
||||
class InterfaceMethod<string desc, string retTy, string methodName,
|
||||
dag args = (ins), code methodBody = [{}]> {
|
||||
dag args = (ins), code methodBody = [{}],
|
||||
code defaultImplementation = [{}]> {
|
||||
// A human-readable description of what this method does.
|
||||
string description = desc;
|
||||
|
||||
|
@ -1440,12 +1441,17 @@ class InterfaceMethod<string desc, string retTy, string methodName,
|
|||
|
||||
// An optional body to the method.
|
||||
code body = methodBody;
|
||||
|
||||
// An optional default implementation of the method.
|
||||
code defaultBody = defaultImplementation;
|
||||
}
|
||||
|
||||
// This class represents a single static interface method.
|
||||
class StaticInterfaceMethod<string desc, string retTy, string methodName,
|
||||
dag args = (ins), code methodBody = [{}]>
|
||||
: InterfaceMethod<desc, retTy, methodName, args, methodBody>;
|
||||
dag args = (ins), code methodBody = [{}],
|
||||
code defaultImplementation = [{}]>
|
||||
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
|
||||
defaultImplementation>;
|
||||
|
||||
// OpInterface represents an interface regarding an op.
|
||||
class OpInterface<string name> : OpInterfaceTrait<name> {
|
||||
|
|
|
@ -58,6 +58,9 @@ public:
|
|||
// Return the body for this method if it has one.
|
||||
llvm::Optional<StringRef> getBody() const;
|
||||
|
||||
// Return the default implementation for this method if it has one.
|
||||
llvm::Optional<StringRef> getDefaultImplementation() const;
|
||||
|
||||
// Return the description of this method if it has one.
|
||||
llvm::Optional<StringRef> getDescription() const;
|
||||
|
||||
|
|
|
@ -57,6 +57,12 @@ llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
|
|||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the default implementation for this method if it has one.
|
||||
llvm::Optional<StringRef> OpInterfaceMethod::getDefaultImplementation() const {
|
||||
auto value = def->getValueAsString("defaultBody");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the description of this method if it has one.
|
||||
llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
|
||||
auto value = def->getValueAsString("description");
|
||||
|
|
|
@ -403,8 +403,7 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
|
|||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
InferTypeOpInterfaceDefault]> {
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
|
|
@ -151,6 +151,29 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
|
|||
os << " };\n";
|
||||
}
|
||||
|
||||
static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
||||
StringRef interfaceName,
|
||||
StringRef interfaceTraitsName) {
|
||||
os << " template <typename ConcreteOp>\n "
|
||||
<< llvm::formatv("struct Trait : public OpInterface<{0},"
|
||||
" detail::{1}>::Trait<ConcreteOp> {{\n",
|
||||
interfaceName, interfaceTraitsName);
|
||||
|
||||
// Insert the default implementation for any methods.
|
||||
for (auto &method : interface.getMethods()) {
|
||||
auto defaultImpl = method.getDefaultImplementation();
|
||||
if (!defaultImpl)
|
||||
continue;
|
||||
|
||||
os << " " << (method.isStatic() ? "static " : "") << method.getReturnType()
|
||||
<< " ";
|
||||
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
|
||||
os << " {\n" << defaultImpl.getValue() << " }\n";
|
||||
}
|
||||
|
||||
os << " };\n";
|
||||
}
|
||||
|
||||
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
|
||||
StringRef interfaceName = interface.getName();
|
||||
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
|
||||
|
@ -168,6 +191,9 @@ static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
|
|||
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
|
||||
interfaceName, interfaceName, interfaceTraitsName);
|
||||
|
||||
// Emit the derived trait for the interface.
|
||||
emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);
|
||||
|
||||
// Insert the method declarations.
|
||||
for (auto &method : interface.getMethods()) {
|
||||
os << " " << method.getReturnType() << " ";
|
||||
|
|
Loading…
Reference in New Issue