forked from OSchip/llvm-project
[mlir] Add an interface for Cast-Like operations
A cast-like operation is one that converts from a set of input types to a set of output types. The arity of the inputs may be from 0-N, whereas the arity of the outputs may be anything from 1-N. Cast-like operations are removable in cases where they produce a "no-op", i.e when the input types and output types match 1-1. Differential Revision: https://reviews.llvm.org/D94831
This commit is contained in:
parent
87a89549c4
commit
6ccf2d62b4
|
@ -182,26 +182,50 @@ to add a new operation to the Toy dialect, `ToyCastOp`(toy.cast), to represent
|
|||
casts between two different shapes.
|
||||
|
||||
```tablegen
|
||||
def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
def CastOp : Toy_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape]
|
||||
> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
must both be tensor types with the same element type. If both are ranked,
|
||||
then shape is required to match. The operation is invalid if converting
|
||||
to a mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
```
|
||||
|
||||
We can then override the necessary hook on the ToyInlinerInterface to insert
|
||||
this for us when necessary:
|
||||
Note that the definition of this cast operation adds a `CastOpInterface` to the
|
||||
traits list. This interface provides several utilities for cast-like operation,
|
||||
such as folding identity casts and verification. We hook into this interface by
|
||||
providing a definition for the `areCastCompatible` method:
|
||||
|
||||
```c++
|
||||
/// Returns true if the given set of input and result types are compatible with
|
||||
/// this cast operation. This is required by the `CastOpInterface` to verify
|
||||
/// this operation and provide other additional utilities.
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
return !input.hasRank() || !output.hasRank() || input == output;
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
With a proper cast operation, we can now override the necessary hook on the
|
||||
ToyInlinerInterface to insert it for us when necessary:
|
||||
|
||||
```c++
|
||||
struct ToyInlinerInterface : public DialectInlinerInterface {
|
||||
|
|
|
@ -29,6 +29,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
|
|||
target_link_libraries(toyc-ch4
|
||||
PRIVATE
|
||||
MLIRAnalysis
|
||||
MLIRCastInterfaces
|
||||
MLIRCallInterfaces
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#define TOY_OPS
|
||||
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
|
||||
|
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
|
|||
];
|
||||
}
|
||||
|
||||
def CastOp : Toy_Op<"cast",
|
||||
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
def CastOp : Toy_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
without changing any data elements. The source and destination types must
|
||||
both be tensor types with the same element type. If both are ranked, then
|
||||
shape is required to match. The operation is invalid if converting to a
|
||||
mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call",
|
||||
|
|
|
@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
|||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
/// Returns true if the given set of input and result types are compatible with
|
||||
/// this cast operation. This is required by the `CastOpInterface` to verify
|
||||
/// this operation and provide other additional utilities.
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
return !input.hasRank() || !output.hasRank() || input == output;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
||||
|
|
|
@ -23,11 +23,6 @@ namespace {
|
|||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||
/// optimizes the following scenario: transpose(transpose(x)) -> x
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
|
|
|
@ -33,6 +33,7 @@ target_link_libraries(toyc-ch5
|
|||
${dialect_libs}
|
||||
MLIRAnalysis
|
||||
MLIRCallInterfaces
|
||||
MLIRCastInterfaces
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#define TOY_OPS
|
||||
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
|
||||
|
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
|
|||
];
|
||||
}
|
||||
|
||||
def CastOp : Toy_Op<"cast",
|
||||
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
def CastOp : Toy_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
without changing any data elements. The source and destination types must
|
||||
both be tensor types with the same element type. If both are ranked, then
|
||||
shape is required to match. The operation is invalid if converting to a
|
||||
mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call",
|
||||
|
|
|
@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
|||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
/// Returns true if the given set of input and result types are compatible with
|
||||
/// this cast operation. This is required by the `CastOpInterface` to verify
|
||||
/// this operation and provide other additional utilities.
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
return !input.hasRank() || !output.hasRank() || input == output;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
||||
|
|
|
@ -23,11 +23,6 @@ namespace {
|
|||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||
/// optimizes the following scenario: transpose(transpose(x)) -> x
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
|
|
|
@ -39,6 +39,7 @@ target_link_libraries(toyc-ch6
|
|||
${conversion_libs}
|
||||
MLIRAnalysis
|
||||
MLIRCallInterfaces
|
||||
MLIRCastInterfaces
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#define TOY_OPS
|
||||
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
|
||||
|
@ -102,25 +103,25 @@ def AddOp : Toy_Op<"add",
|
|||
];
|
||||
}
|
||||
|
||||
def CastOp : Toy_Op<"cast",
|
||||
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
def CastOp : Toy_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
without changing any data elements. The source and destination types must
|
||||
both be tensor types with the same element type. If both are ranked, then
|
||||
shape is required to match. The operation is invalid if converting to a
|
||||
mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call",
|
||||
|
|
|
@ -232,6 +232,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
|||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
/// Returns true if the given set of input and result types are compatible with
|
||||
/// this cast operation. This is required by the `CastOpInterface` to verify
|
||||
/// this operation and provide other additional utilities.
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
return !input.hasRank() || !output.hasRank() || input == output;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
||||
|
|
|
@ -23,11 +23,6 @@ namespace {
|
|||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||
/// optimizes the following scenario: transpose(transpose(x)) -> x
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
|
|
|
@ -39,6 +39,7 @@ target_link_libraries(toyc-ch7
|
|||
${conversion_libs}
|
||||
MLIRAnalysis
|
||||
MLIRCallInterfaces
|
||||
MLIRCastInterfaces
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#define TOY_OPS
|
||||
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
|
||||
|
@ -115,25 +116,25 @@ def AddOp : Toy_Op<"add",
|
|||
];
|
||||
}
|
||||
|
||||
def CastOp : Toy_Op<"cast",
|
||||
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
def CastOp : Toy_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
without changing any data elements. The source and destination types must
|
||||
both be tensor types with the same element type. If both are ranked, then
|
||||
shape is required to match. The operation is invalid if converting to a
|
||||
mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call",
|
||||
|
|
|
@ -284,6 +284,21 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
|||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
/// Returns true if the given set of input and result types are compatible with
|
||||
/// this cast operation. This is required by the `CastOpInterface` to verify
|
||||
/// this operation and provide other additional utilities.
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
return !input.hasRank() || !output.hasRank() || input == output;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
||||
|
|
|
@ -23,11 +23,6 @@ namespace {
|
|||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// Fold constants.
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
|
|
@ -17,6 +17,7 @@ include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
|
|||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/VectorInterfaces.td"
|
||||
|
@ -45,9 +46,10 @@ class Std_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// Base class for standard cast operations. Requires single operand and result,
|
||||
// but does not constrain them to specific types.
|
||||
class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Std_Op<mnemonic,
|
||||
!listconcat(traits, [NoSideEffect, SameOperandsAndResultShape])> {
|
||||
|
||||
Std_Op<mnemonic, traits # [
|
||||
NoSideEffect, SameOperandsAndResultShape,
|
||||
DeclareOpInterfaceMethods<CastOpInterface>
|
||||
]> {
|
||||
let results = (outs AnyType);
|
||||
|
||||
let builders = [
|
||||
|
@ -62,9 +64,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
let printer = [{
|
||||
return printStandardCastOp(this->getOperation(), p);
|
||||
}];
|
||||
let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
// Cast operations are fully verified by its traits.
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
// Base class for arithmetic cast operations.
|
||||
|
@ -1643,14 +1645,6 @@ def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
|
|||
The destination type must to be strictly wider than the source type.
|
||||
Only scalars are currently supported.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1663,14 +1657,6 @@ def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
|
|||
Cast from a value interpreted as floating-point to the nearest (rounding
|
||||
towards zero) signed integer value.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1683,14 +1669,6 @@ def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
|
|||
Cast from a value interpreted as floating-point to the nearest (rounding
|
||||
towards zero) unsigned integer value.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1705,14 +1683,6 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
|
|||
If the value cannot be exactly represented, it is rounded using the default
|
||||
rounding mode. Only scalars are currently supported.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1849,12 +1819,6 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
|
|||
sign-extended. If casting to a narrower integer, the value is truncated.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
@ -2045,14 +2009,7 @@ def MemRefCastOp : CastOp<"memref_cast", [
|
|||
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
|
||||
let results = (outs AnyRankedOrUnrankedMemRef);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
Type getType() { return getResult().getType(); }
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
||||
|
@ -2786,14 +2743,6 @@ def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
|
|||
exactly represented, it is rounded using the default rounding mode. Scalars
|
||||
and vector types are currently supported.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3628,14 +3577,6 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
|
|||
value cannot be exactly represented, it is rounded using the default
|
||||
rounding mode. Scalars and vector types are currently supported.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define TENSOR_OPS
|
||||
|
||||
include "mlir/Dialect/Tensor/IR/TensorBase.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
|
@ -24,7 +25,9 @@ class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
|
|||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
|
||||
def Tensor_CastOp : Tensor_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
|
||||
]> {
|
||||
let summary = "tensor cast operation";
|
||||
let description = [{
|
||||
Convert a tensor from one type to an equivalent type without changing any
|
||||
|
@ -51,19 +54,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> {
|
|||
let arguments = (ins AnyTensor:$source);
|
||||
let results = (outs AnyTensor:$dest);
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
|
||||
let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a tensor.cast is always a tensor.
|
||||
TensorType getType() { return getResult().getType().cast<TensorType>(); }
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -50,6 +50,34 @@ enum class DiagnosticSeverity {
|
|||
/// A variant type that holds a single argument for a diagnostic.
|
||||
class DiagnosticArgument {
|
||||
public:
|
||||
/// Note: The constructors below are only exposed due to problems accessing
|
||||
/// constructors from type traits, they should not be used directly by users.
|
||||
// Construct from an Attribute.
|
||||
explicit DiagnosticArgument(Attribute attr);
|
||||
// Construct from a floating point number.
|
||||
explicit DiagnosticArgument(double val)
|
||||
: kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
|
||||
explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
|
||||
// Construct from a signed integer.
|
||||
template <typename T>
|
||||
explicit DiagnosticArgument(
|
||||
T val, typename std::enable_if<std::is_signed<T>::value &&
|
||||
std::numeric_limits<T>::is_integer &&
|
||||
sizeof(T) <= sizeof(int64_t)>::type * = 0)
|
||||
: kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
|
||||
// Construct from an unsigned integer.
|
||||
template <typename T>
|
||||
explicit DiagnosticArgument(
|
||||
T val, typename std::enable_if<std::is_unsigned<T>::value &&
|
||||
std::numeric_limits<T>::is_integer &&
|
||||
sizeof(T) <= sizeof(uint64_t)>::type * = 0)
|
||||
: kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
|
||||
// Construct from a string reference.
|
||||
explicit DiagnosticArgument(StringRef val)
|
||||
: kind(DiagnosticArgumentKind::String), stringVal(val) {}
|
||||
// Construct from a Type.
|
||||
explicit DiagnosticArgument(Type val);
|
||||
|
||||
/// Enum that represents the different kinds of diagnostic arguments
|
||||
/// supported.
|
||||
enum class DiagnosticArgumentKind {
|
||||
|
@ -100,37 +128,6 @@ public:
|
|||
private:
|
||||
friend class Diagnostic;
|
||||
|
||||
// Construct from an Attribute.
|
||||
explicit DiagnosticArgument(Attribute attr);
|
||||
|
||||
// Construct from a floating point number.
|
||||
explicit DiagnosticArgument(double val)
|
||||
: kind(DiagnosticArgumentKind::Double), doubleVal(val) {}
|
||||
explicit DiagnosticArgument(float val) : DiagnosticArgument(double(val)) {}
|
||||
|
||||
// Construct from a signed integer.
|
||||
template <typename T>
|
||||
explicit DiagnosticArgument(
|
||||
T val, typename std::enable_if<std::is_signed<T>::value &&
|
||||
std::numeric_limits<T>::is_integer &&
|
||||
sizeof(T) <= sizeof(int64_t)>::type * = 0)
|
||||
: kind(DiagnosticArgumentKind::Integer), opaqueVal(int64_t(val)) {}
|
||||
|
||||
// Construct from an unsigned integer.
|
||||
template <typename T>
|
||||
explicit DiagnosticArgument(
|
||||
T val, typename std::enable_if<std::is_unsigned<T>::value &&
|
||||
std::numeric_limits<T>::is_integer &&
|
||||
sizeof(T) <= sizeof(uint64_t)>::type * = 0)
|
||||
: kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {}
|
||||
|
||||
// Construct from a string reference.
|
||||
explicit DiagnosticArgument(StringRef val)
|
||||
: kind(DiagnosticArgumentKind::String), stringVal(val) {}
|
||||
|
||||
// Construct from a Type.
|
||||
explicit DiagnosticArgument(Type val);
|
||||
|
||||
/// The kind of this argument.
|
||||
DiagnosticArgumentKind kind;
|
||||
|
||||
|
@ -189,8 +186,10 @@ public:
|
|||
|
||||
/// Stream operator for inserting new diagnostic arguments.
|
||||
template <typename Arg>
|
||||
typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
|
||||
Diagnostic &>::type
|
||||
typename std::enable_if<
|
||||
!std::is_convertible<Arg, StringRef>::value &&
|
||||
std::is_constructible<DiagnosticArgument, Arg>::value,
|
||||
Diagnostic &>::type
|
||||
operator<<(Arg &&val) {
|
||||
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
|
||||
return *this;
|
||||
|
@ -220,17 +219,17 @@ public:
|
|||
}
|
||||
|
||||
/// Stream in a range.
|
||||
template <typename T> Diagnostic &operator<<(iterator_range<T> range) {
|
||||
return appendRange(range);
|
||||
}
|
||||
template <typename T> Diagnostic &operator<<(ArrayRef<T> range) {
|
||||
template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
|
||||
std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
|
||||
Diagnostic &>
|
||||
operator<<(T &&range) {
|
||||
return appendRange(range);
|
||||
}
|
||||
|
||||
/// Append a range to the diagnostic. The default delimiter between elements
|
||||
/// is ','.
|
||||
template <typename T, template <typename> class Container>
|
||||
Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
|
||||
template <typename T>
|
||||
Diagnostic &appendRange(const T &c, const char *delim = ", ") {
|
||||
llvm::interleave(
|
||||
c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
|
||||
return *this;
|
||||
|
|
|
@ -1822,18 +1822,27 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
|||
void printOneResultOp(Operation *op, OpAsmPrinter &p);
|
||||
} // namespace impl
|
||||
|
||||
// These functions are out-of-line implementations of the methods in CastOp,
|
||||
// which avoids them being template instantiated/duplicated.
|
||||
// These functions are out-of-line implementations of the methods in
|
||||
// CastOpInterface, which avoids them being template instantiated/duplicated.
|
||||
namespace impl {
|
||||
/// Attempt to fold the given cast operation.
|
||||
LogicalResult foldCastInterfaceOp(Operation *op,
|
||||
ArrayRef<Attribute> attrOperands,
|
||||
SmallVectorImpl<OpFoldResult> &foldResults);
|
||||
/// Attempt to verify the given cast operation.
|
||||
LogicalResult verifyCastInterfaceOp(
|
||||
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
|
||||
|
||||
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
|
||||
// need for them, but some older ODS code in `std` still depends on them).
|
||||
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
|
||||
Type destType);
|
||||
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
|
||||
void printCastOp(Operation *op, OpAsmPrinter &p);
|
||||
// TODO: Create a CastOpInterface with a method areCastCompatible.
|
||||
// Also, consider adding functionality to CastOpInterface to be able to perform
|
||||
// the ChainedTensorCast canonicalization generically.
|
||||
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
|
||||
// when all uses have been updated. Also, consider adding functionality to
|
||||
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
|
||||
// generically.
|
||||
Value foldCastOp(Operation *op);
|
||||
LogicalResult verifyCastOp(Operation *op,
|
||||
function_ref<bool(Type, Type)> areCastCompatible);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_interface(CallInterfaces)
|
||||
add_mlir_interface(CastInterfaces)
|
||||
add_mlir_interface(ControlFlowInterfaces)
|
||||
add_mlir_interface(CopyOpInterface)
|
||||
add_mlir_interface(DerivedAttributeOpInterface)
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
//===- CastInterfaces.h - Cast Interfaces for MLIR --------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains the definitions of the cast interfaces defined in
|
||||
// `CastInterfaces.td`.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_CASTINTERFACES_H
|
||||
#define MLIR_INTERFACES_CASTINTERFACES_H
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
/// Include the generated interface declarations.
|
||||
#include "mlir/Interfaces/CastInterfaces.h.inc"
|
||||
|
||||
#endif // MLIR_INTERFACES_CASTINTERFACES_H
|
|
@ -0,0 +1,51 @@
|
|||
//===- CastInterfaces.td - Cast Interfaces for ops ---------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains a set of interfaces that can be used to define information
|
||||
// related to cast-like operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_CASTINTERFACES
|
||||
#define MLIR_INTERFACES_CASTINTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def CastOpInterface : OpInterface<"CastOpInterface"> {
|
||||
let description = [{
|
||||
A cast-like operation is one that converts from a set of input types to a
|
||||
set of output types. The arity of the inputs may be from 0-N, whereas the
|
||||
arity of the outputs may be anything from 1-N. Cast-like operations are
|
||||
trivially removable in cases where they produce an No-op, i.e when the
|
||||
input types and output types match 1-1.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
StaticInterfaceMethod<[{
|
||||
Returns true if the given set of input and result types are compatible
|
||||
to cast using this cast operation.
|
||||
}],
|
||||
"bool", "areCastCompatible",
|
||||
(ins "mlir::TypeRange":$inputs, "mlir::TypeRange":$outputs)
|
||||
>,
|
||||
];
|
||||
|
||||
let extraTraitClassDeclaration = [{
|
||||
/// Attempt to fold the given cast operation.
|
||||
static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return impl::foldCastInterfaceOp(op, operands, results);
|
||||
}
|
||||
}];
|
||||
let verify = [{
|
||||
return impl::verifyCastInterfaceOp($_op, ConcreteOp::areCastCompatible);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_INTERFACES_CASTINTERFACES
|
|
@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRShape
|
|||
MLIRShapeOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCastInterfaces
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRDialect
|
||||
MLIRInferTypeOpInterface
|
||||
|
|
|
@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRStandard
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCallInterfaces
|
||||
MLIRCastInterfaces
|
||||
MLIRControlFlowInterfaces
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
|
|
|
@ -195,7 +195,8 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
|||
/// Returns 'true' if the vector types are cast compatible, and 'false'
|
||||
/// otherwise.
|
||||
static bool areVectorCastSimpleCompatible(
|
||||
Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
|
||||
Type a, Type b,
|
||||
function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
|
||||
if (auto va = a.dyn_cast<VectorType>())
|
||||
if (auto vb = b.dyn_cast<VectorType>())
|
||||
return va.getShape().equals(vb.getShape()) &&
|
||||
|
@ -1746,7 +1747,10 @@ LogicalResult DmaWaitOp::verify() {
|
|||
// FPExtOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool FPExtOp::areCastCompatible(Type a, Type b) {
|
||||
bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (auto fa = a.dyn_cast<FloatType>())
|
||||
if (auto fb = b.dyn_cast<FloatType>())
|
||||
return fa.getWidth() < fb.getWidth();
|
||||
|
@ -1757,7 +1761,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
|
|||
// FPToSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool FPToSIOp::areCastCompatible(Type a, Type b) {
|
||||
bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (a.isa<FloatType>() && b.isSignlessInteger())
|
||||
return true;
|
||||
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
|
||||
|
@ -1767,7 +1774,10 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
|
|||
// FPToUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool FPToUIOp::areCastCompatible(Type a, Type b) {
|
||||
bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (a.isa<FloatType>() && b.isSignlessInteger())
|
||||
return true;
|
||||
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
|
||||
|
@ -1777,7 +1787,10 @@ bool FPToUIOp::areCastCompatible(Type a, Type b) {
|
|||
// FPTruncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool FPTruncOp::areCastCompatible(Type a, Type b) {
|
||||
bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (auto fa = a.dyn_cast<FloatType>())
|
||||
if (auto fb = b.dyn_cast<FloatType>())
|
||||
return fa.getWidth() > fb.getWidth();
|
||||
|
@ -1889,7 +1902,10 @@ GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Index cast is applicable from index to integer and backwards.
|
||||
bool IndexCastOp::areCastCompatible(Type a, Type b) {
|
||||
bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
|
||||
auto aShaped = a.cast<ShapedType>();
|
||||
auto bShaped = b.cast<ShapedType>();
|
||||
|
@ -1965,7 +1981,10 @@ void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
|
||||
Value MemRefCastOp::getViewSource() { return source(); }
|
||||
|
||||
bool MemRefCastOp::areCastCompatible(Type a, Type b) {
|
||||
bool MemRefCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
auto aT = a.dyn_cast<MemRefType>();
|
||||
auto bT = b.dyn_cast<MemRefType>();
|
||||
|
||||
|
@ -2036,8 +2055,6 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
|
|||
}
|
||||
|
||||
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (Value folded = impl::foldCastOp(*this))
|
||||
return folded;
|
||||
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
|
||||
}
|
||||
|
||||
|
@ -2633,7 +2650,10 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// sitofp is applicable from integer types to float types.
|
||||
bool SIToFPOp::areCastCompatible(Type a, Type b) {
|
||||
bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (a.isSignlessInteger() && b.isa<FloatType>())
|
||||
return true;
|
||||
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
|
||||
|
@ -2715,7 +2735,10 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// uitofp is applicable from integer types to float types.
|
||||
bool UIToFPOp::areCastCompatible(Type a, Type b) {
|
||||
bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
if (a.isSignlessInteger() && b.isa<FloatType>())
|
||||
return true;
|
||||
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
|
||||
|
|
|
@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTensor
|
|||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCastInterfaces
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
|
|
|
@ -73,7 +73,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CastOp::areCastCompatible(Type a, Type b) {
|
||||
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
Type a = inputs.front(), b = outputs.front();
|
||||
auto aT = a.dyn_cast<TensorType>();
|
||||
auto bT = b.dyn_cast<TensorType>();
|
||||
if (!aT || !bT)
|
||||
|
@ -85,10 +88,6 @@ bool CastOp::areCastCompatible(Type a, Type b) {
|
|||
return succeeded(verifyCompatibleShape(aT, bT));
|
||||
}
|
||||
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// Compute a TensorType that has the joined shape knowledge of the two
|
||||
/// given TensorTypes. The element types need to match.
|
||||
static TensorType joinShapes(TensorType one, TensorType two) {
|
||||
|
|
|
@ -1208,6 +1208,48 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
|
|||
// CastOp implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Attempt to fold the given cast operation.
|
||||
LogicalResult
|
||||
impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
|
||||
SmallVectorImpl<OpFoldResult> &foldResults) {
|
||||
OperandRange operands = op->getOperands();
|
||||
if (operands.empty())
|
||||
return failure();
|
||||
ResultRange results = op->getResults();
|
||||
|
||||
// Check for the case where the input and output types match 1-1.
|
||||
if (operands.getTypes() == results.getTypes()) {
|
||||
foldResults.append(operands.begin(), operands.end());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Attempt to verify the given cast operation.
|
||||
LogicalResult impl::verifyCastInterfaceOp(
|
||||
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) {
|
||||
auto resultTypes = op->getResultTypes();
|
||||
if (llvm::empty(resultTypes))
|
||||
return op->emitOpError()
|
||||
<< "expected at least one result for cast operation";
|
||||
|
||||
auto operandTypes = op->getOperandTypes();
|
||||
if (!areCastCompatible(operandTypes, resultTypes)) {
|
||||
InFlightDiagnostic diag = op->emitOpError("operand type");
|
||||
if (llvm::empty(operandTypes))
|
||||
diag << "s []";
|
||||
else if (llvm::size(operandTypes) == 1)
|
||||
diag << " " << *operandTypes.begin();
|
||||
else
|
||||
diag << "s " << operandTypes;
|
||||
return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
|
||||
<< resultTypes << " are cast incompatible";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
|
||||
Type destType) {
|
||||
result.addOperands(source);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
set(LLVM_OPTIONAL_SOURCES
|
||||
CallInterfaces.cpp
|
||||
CastInterfaces.cpp
|
||||
ControlFlowInterfaces.cpp
|
||||
CopyOpInterface.cpp
|
||||
DerivedAttributeOpInterface.cpp
|
||||
|
@ -27,6 +28,7 @@ endfunction(add_mlir_interface_library)
|
|||
|
||||
|
||||
add_mlir_interface_library(CallInterfaces)
|
||||
add_mlir_interface_library(CastInterfaces)
|
||||
add_mlir_interface_library(ControlFlowInterfaces)
|
||||
add_mlir_interface_library(CopyOpInterface)
|
||||
add_mlir_interface_library(DerivedAttributeOpInterface)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
//===- CastInterfaces.cpp -------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Table-generated class definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/CastInterfaces.cpp.inc"
|
Loading…
Reference in New Issue