[mlir] Add an interface for Cast-Like operations

A cast-like operation is one that converts from a set of input types to a set of output types. The arity of the inputs may be from 0-N, whereas the arity of the outputs may be anything from 1-N. Cast-like operations are removable in cases where they produce a "no-op", i.e. when the input types and output types match 1-1.

Differential Revision: https://reviews.llvm.org/D94831
This commit is contained in:
River Riddle 2021-01-20 16:17:13 -08:00
parent 87a89549c4
commit 6ccf2d62b4
38 changed files with 388 additions and 208 deletions

View File

@ -182,26 +182,50 @@ to add a new operation to the Toy dialect, `ToyCastOp`(toy.cast), to represent
casts between two different shapes.
```tablegen
def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
def CastOp : Toy_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
NoSideEffect,
SameOperandsAndResultShape]
> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
must both be tensor types with the same element type. If both are ranked,
then shape is required to match. The operation is invalid if converting
to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
```
We can then override the necessary hook on the ToyInlinerInterface to insert
this for us when necessary:
Note that the definition of this cast operation adds a `CastOpInterface` to the
traits list. This interface provides several utilities for cast-like operations,
such as folding identity casts and verification. We hook into this interface by
providing a definition for the `areCastCompatible` method:
```c++
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// toy.cast is strictly a 1-input, 1-result cast; reject any other arity.
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
// If either side is unranked, the cast is accepted as compatible.
return !input.hasRank() || !output.hasRank() || input == output;
}
```
With a proper cast operation, we can now override the necessary hook on the
ToyInlinerInterface to insert it for us when necessary:
```c++
struct ToyInlinerInterface : public DialectInlinerInterface {

View File

@ -29,6 +29,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch4
PRIVATE
MLIRAnalysis
MLIRCastInterfaces
MLIRCallInterfaces
MLIRIR
MLIRParser

View File

@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
def CastOp : Toy_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
NoSideEffect,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
without changing any data elements. The source and destination types must
both be tensor types with the same element type. If both are ranked, then
shape is required to match. The operation is invalid if converting to a
mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",

View File

@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// toy.cast is strictly a 1-input, 1-result cast; reject any other arity.
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
// If either side is unranked, the cast is accepted as compatible.
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// GenericCallOp

View File

@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return mlir::impl::foldCastOp(*this);
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

View File

@ -33,6 +33,7 @@ target_link_libraries(toyc-ch5
${dialect_libs}
MLIRAnalysis
MLIRCallInterfaces
MLIRCastInterfaces
MLIRIR
MLIRParser
MLIRPass

View File

@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
def CastOp : Toy_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
NoSideEffect,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
without changing any data elements. The source and destination types must
both be tensor types with the same element type. If both are ranked, then
shape is required to match. The operation is invalid if converting to a
mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",

View File

@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// toy.cast is strictly a 1-input, 1-result cast; reject any other arity.
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
// If either side is unranked, the cast is accepted as compatible.
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// GenericCallOp

View File

@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return mlir::impl::foldCastOp(*this);
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

View File

@ -39,6 +39,7 @@ target_link_libraries(toyc-ch6
${conversion_libs}
MLIRAnalysis
MLIRCallInterfaces
MLIRCastInterfaces
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR

View File

@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
];
}
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
def CastOp : Toy_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
NoSideEffect,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
without changing any data elements. The source and destination types must
both be tensor types with the same element type. If both are ranked, then
shape is required to match. The operation is invalid if converting to a
mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",

View File

@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// toy.cast is strictly a 1-input, 1-result cast; reject any other arity.
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
// If either side is unranked, the cast is accepted as compatible.
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// GenericCallOp

View File

@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return mlir::impl::foldCastOp(*this);
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

View File

@ -39,6 +39,7 @@ target_link_libraries(toyc-ch7
${conversion_libs}
MLIRAnalysis
MLIRCallInterfaces
MLIRCastInterfaces
MLIRExecutionEngine
MLIRIR
MLIRParser

View File

@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -14,6 +14,7 @@
#define TOY_OPS
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "toy/ShapeInferenceInterface.td"
@ -115,25 +116,25 @@ def AddOp : Toy_Op<"add",
];
}
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
def CastOp : Toy_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
NoSideEffect,
SameOperandsAndResultShape
]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
without changing any data elements. The source and destination types must
both be tensor types with the same element type. If both are ranked, then
shape is required to match. The operation is invalid if converting to a
mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",

View File

@ -284,6 +284,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// toy.cast is strictly a 1-input, 1-result cast; reject any other arity.
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
// If either side is unranked, the cast is accepted as compatible.
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// GenericCallOp

View File

@ -23,11 +23,6 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return mlir::impl::foldCastOp(*this);
}
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }

View File

@ -19,6 +19,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"

View File

@ -17,6 +17,7 @@ include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
@ -45,9 +46,10 @@ class Std_Op<string mnemonic, list<OpTrait> traits = []> :
// Base class for standard cast operations. Requires single operand and result,
// but does not constrain them to specific types.
class CastOp<string mnemonic, list<OpTrait> traits = []> :
Std_Op<mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> {
Std_Op<mnemonic, traits # [
NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<CastOpInterface>
]> {
let results = (outs AnyType);
let builders = [
@ -62,9 +64,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
let hasFolder = 1;
// Cast operations are fully verified by its traits.
let verifier = ?;
}
// Base class for arithmetic cast operations.
@ -1643,14 +1645,6 @@ def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
The destination type must be strictly wider than the source type.
Only scalars are currently supported.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@ -1663,14 +1657,6 @@ def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) signed integer value.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@ -1683,14 +1669,6 @@ def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) unsigned integer value.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@ -1705,14 +1683,6 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
If the value cannot be exactly represented, it is rounded using the default
rounding mode. Only scalars are currently supported.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@ -1849,12 +1819,6 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
sign-extended. If casting to a narrower integer, the value is truncated.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 1;
}
@ -2045,14 +2009,7 @@ def MemRefCastOp : CastOp<"memref_cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a memref_cast is always a memref.
Type getType() { return getResult().getType(); }
}];
let hasFolder = 1;
}
@ -2786,14 +2743,6 @@ def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
exactly represented, it is rounded using the default rounding mode. Scalars
and vector types are currently supported.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
@ -3628,14 +3577,6 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
value cannot be exactly represented, it is rounded using the default
rounding mode. Scalars and vector types are currently supported.
}];
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
}];
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//

View File

@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -10,6 +10,7 @@
#define TENSOR_OPS
include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -24,7 +25,9 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
// CastOp
//===----------------------------------------------------------------------===//
def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
def Tensor_CastOp : Tensor_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
]> {
let summary = "tensor cast operation";
let description = [{
Convert a tensor from one type to an equivalent type without changing any
@ -51,19 +54,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a tensor.cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let verifier = ?;
}
//===----------------------------------------------------------------------===//

View File

@ -50,6 +50,34 @@ enum class DiagnosticSeverity {
/// A variant type that holds a single argument for a diagnostic.
class DiagnosticArgument {
public:
/// Note: The constructors below are only exposed due to problems accessing
/// constructors from type traits, they should not be used directly by users.
// Construct from an Attribute.
explicit DiagnosticArgument(Attribute attr);
// Construct from a floating point number.
explicit DiagnosticArgument(double val)
: kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
// Construct from a signed integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_signed<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(int64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
// Construct from an unsigned integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_unsigned<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(uint64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
// Construct from a string reference.
explicit DiagnosticArgument(StringRef val)
: kind(DiagnosticArgumentKind::String), stringVal(val) {}
// Construct from a Type.
explicit DiagnosticArgument(Type val);
/// Enum that represents the different kinds of diagnostic arguments
/// supported.
enum class DiagnosticArgumentKind {
@ -100,37 +128,6 @@ public:
private:
friend class Diagnostic;
// Construct from an Attribute.
explicit DiagnosticArgument(Attribute attr);
// Construct from a floating point number.
explicit DiagnosticArgument(double val)
: kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
// Construct from a signed integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_signed<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(int64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
// Construct from an unsigned integer.
template <typename T>
explicit DiagnosticArgument(
T val, typename std::enable_if<std::is_unsigned<T>::value &&
std::numeric_limits<T>::is_integer &&
sizeof(T) <= sizeof(uint64_t)>::type * = 0)
: kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
// Construct from a string reference.
explicit DiagnosticArgument(StringRef val)
: kind(DiagnosticArgumentKind::String), stringVal(val) {}
// Construct from a Type.
explicit DiagnosticArgument(Type val);
/// The kind of this argument.
DiagnosticArgumentKind kind;
@ -189,8 +186,10 @@ public:
/// Stream operator for inserting new diagnostic arguments.
template <typename Arg>
typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
Diagnostic &>::type
typename std::enable_if<
!std::is_convertible<Arg, StringRef>::value &&
std::is_constructible<DiagnosticArgument, Arg>::value,
Diagnostic &>::type
operator<<(Arg &&val) {
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
return *this;
@ -220,17 +219,17 @@ public:
}
/// Stream in a range.
template <typename T> Diagnostic &operator<<(iterator_range<T> range) {
return appendRange(range);
}
template <typename T> Diagnostic &operator<<(ArrayRef<T> range) {
template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
Diagnostic &>
operator<<(T &&range) {
return appendRange(range);
}
/// Append a range to the diagnostic. The default delimiter between elements
/// is ','.
template <typename T, template <typename> class Container>
Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
template <typename T>
Diagnostic &appendRange(const T &c, const char *delim = ", ") {
llvm::interleave(
c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
return *this;

View File

@ -1822,18 +1822,27 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
/// Attempt to fold the given cast operation.
LogicalResult foldCastInterfaceOp(Operation *op,
ArrayRef<Attribute> attrOperands,
SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation.
LogicalResult verifyCastInterfaceOp(
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: Create a CastOpInterface with a method areCastCompatible.
// Also, consider adding functionality to CastOpInterface to be able to perform
// the ChainedTensorCast canonicalization generically.
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);

View File

@ -1,4 +1,5 @@
add_mlir_interface(CallInterfaces)
add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)

View File

@ -0,0 +1,22 @@
//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the cast interfaces defined in
// `CastInterfaces.td`.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_CASTINTERFACES_H
#define MLIR_INTERFACES_CASTINTERFACES_H
#include "mlir/IR/OpDefinition.h"
/// Include the generated interface declarations.
#include "mlir/Interfaces/CastInterfaces.h.inc"
#endif // MLIR_INTERFACES_CASTINTERFACES_H

View File

@ -0,0 +1,51 @@
//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to cast-like operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_CASTINTERFACES
#define MLIR_INTERFACES_CASTINTERFACES
include "mlir/IR/OpBase.td"
def CastOpInterface : OpInterface<"CastOpInterface"> {
let description = [{
A cast-like operation is one that converts from a set of input types to a
set of output types. The arity of the inputs may be from 0-N, whereas the
arity of the outputs may be anything from 1-N. Cast-like operations are
trivially removable in cases where they produce a no-op, i.e. when the
input types and output types match 1-1.
}];
let cppNamespace = "::mlir";
let methods = [
StaticInterfaceMethod<[{
Returns true if the given set of input and result types are compatible
to cast using this cast operation.
}],
"bool", "areCastCompatible",
(ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs)
>,
];
let extraTraitClassDeclaration = [{
/// Attempt to fold the given cast operation.
static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return impl::foldCastInterfaceOp(op, operands, results);
}
}];
let verify = [{
return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible);
}];
}
#endif // MLIR_INTERFACES_CASTINTERFACES

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRShape
MLIRShapeOpsIncGen
LINK_LIBS PUBLIC
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRDialect
MLIRInferTypeOpInterface

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRStandard
LINK_LIBS PUBLIC
MLIRCallInterfaces
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIREDSC
MLIRIR

View File

@ -195,7 +195,8 @@ static LogicalResult foldMemRefCast(Operation *op) {
/// Returns 'true' if the vector types are cast compatible, and 'false'
/// otherwise.
static bool areVectorCastSimpleCompatible(
Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
Type a, Type b,
function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
if (auto va = a.dyn_cast<VectorType>())
if (auto vb = b.dyn_cast<VectorType>())
return va.getShape().equals(vb.getShape()) &&
@ -1746,7 +1747,10 @@ LogicalResult DmaWaitOp::verify() {
// FPExtOp
//===----------------------------------------------------------------------===//
bool FPExtOp::areCastCompatible(Type a, Type b) {
bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() < fb.getWidth();
@ -1757,7 +1761,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
// FPToSIOp
//===----------------------------------------------------------------------===//
bool FPToSIOp::areCastCompatible(Type a, Type b) {
bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@ -1767,7 +1774,10 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
// FPToUIOp
//===----------------------------------------------------------------------===//
bool FPToUIOp::areCastCompatible(Type a, Type b) {
bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@ -1777,7 +1787,10 @@ bool FPToUIOp::areCastCompatible(Type a, Type b) {
// FPTruncOp
//===----------------------------------------------------------------------===//
bool FPTruncOp::areCastCompatible(Type a, Type b) {
bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() > fb.getWidth();
@ -1889,7 +1902,10 @@ GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
// Index cast is applicable from index to integer and backwards.
bool IndexCastOp::areCastCompatible(Type a, Type b) {
bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
auto aShaped = a.cast<ShapedType>();
auto bShaped = b.cast<ShapedType>();
@ -1965,7 +1981,10 @@ void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
Value MemRefCastOp::getViewSource() { return source(); }
bool MemRefCastOp::areCastCompatible(Type a, Type b) {
bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
auto aT = a.dyn_cast<MemRefType>();
auto bT = b.dyn_cast<MemRefType>();
@ -2036,8 +2055,6 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
}
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
if (Value folded = impl::foldCastOp(*this))
return folded;
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
@ -2633,7 +2650,10 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
// sitofp is applicable from integer types to float types.
bool SIToFPOp::areCastCompatible(Type a, Type b) {
bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
@ -2715,7 +2735,10 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
// uitofp is applicable from integer types to float types.
bool UIToFPOp::areCastCompatible(Type a, Type b) {
bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTensor
Core
LINK_LIBS PUBLIC
MLIRCastInterfaces
MLIRIR
MLIRSideEffectInterfaces
MLIRSupport

View File

@ -73,7 +73,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
return true;
}
bool CastOp::areCastCompatible(Type a, Type b) {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
auto aT = a.dyn_cast<TensorType>();
auto bT = b.dyn_cast<TensorType>();
if (!aT || !bT)
@ -85,10 +88,6 @@ bool CastOp::areCastCompatible(Type a, Type b) {
return succeeded(verifyCompatibleShape(aT, bT));
}
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {

View File

@ -1208,6 +1208,48 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
// CastOp implementation
//===----------------------------------------------------------------------===//
/// Attempt to fold the given cast operation.
///
/// Implements the identity-cast fold: when the operand types match the result
/// types 1-1, the cast is a no-op and every result folds to its corresponding
/// operand. On success, `foldResults` is populated with the operands and
/// success() is returned; otherwise failure() is returned and `foldResults`
/// is left untouched. `attrOperands` (the constant operand attributes) is
/// intentionally unused — this fold depends only on types.
LogicalResult
impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
                          SmallVectorImpl<OpFoldResult> &foldResults) {
  OperandRange operands = op->getOperands();
  // A cast with zero inputs (arity 0-N is allowed for inputs) has no operand
  // to fold to, so there is nothing to do.
  if (operands.empty())
    return failure();
  ResultRange results = op->getResults();
  // Check for the case where the input and output types match 1-1.
  if (operands.getTypes() == results.getTypes()) {
    foldResults.append(operands.begin(), operands.end());
    return success();
  }
  return failure();
}
/// Attempt to verify the given cast operation.
///
/// Checks the two structural invariants of a cast-like op:
///  1. it produces at least one result (input arity may be 0-N, but result
///     arity must be 1-N), and
///  2. the operand/result type lists are accepted by the op's
///     `areCastCompatible` hook, passed in as `areCastCompatible`.
/// On failure, emits an error diagnostic whose "operand type(s)" wording is
/// pluralized to match the actual operand/result arity.
LogicalResult impl::verifyCastInterfaceOp(
    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
  auto resultTypes = op->getResultTypes();
  if (llvm::empty(resultTypes))
    return op->emitOpError()
           << "expected at least one result for cast operation";
  auto operandTypes = op->getOperandTypes();
  if (!areCastCompatible(operandTypes, resultTypes)) {
    InFlightDiagnostic diag = op->emitOpError("operand type");
    // Render the operand type list: "[]" for zero operands, a bare type for
    // one, or the comma-separated list for several.
    if (llvm::empty(operandTypes))
      diag << "s []";
    else if (llvm::size(operandTypes) == 1)
      diag << " " << *operandTypes.begin();
    else
      diag << "s " << operandTypes;
    return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
                << resultTypes << " are cast incompatible";
  }
  return success();
}
void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType) {
result.addOperands(source);

View File

@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CallInterfaces.cpp
CastInterfaces.cpp
ControlFlowInterfaces.cpp
CopyOpInterface.cpp
DerivedAttributeOpInterface.cpp
@ -27,6 +28,7 @@ endfunction(add_mlir_interface_library)
add_mlir_interface_library(CallInterfaces)
add_mlir_interface_library(CastInterfaces)
add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DerivedAttributeOpInterface)

View File

@ -0,0 +1,17 @@
//===- CastInterfaces.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/CastInterfaces.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Table-generated class definitions
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/CastInterfaces.cpp.inc"