forked from OSchip/llvm-project
[ODS] Make the getType() method on a OneResult instruction return a specific type.
Implement Bug 46698, making ODS synthesize a getType() method that returns a specific C++ class for OneResult methods where we know that class. This eliminates a common source of casts in things like: myOp.getType().cast<FIRRTLType>().getPassive() because we know that myOp always returns a FIRRTLType. This also encourages op authors to type their results more tightly (which is also good for verification). I chose to implement this by splitting the OneResult trait into itself plus a OneTypedResult trait, given that many things are using `hasTrait<OneResult>` to conditionalize various logic. While this changes makes many many ops get more specific getType() results, it is generally drop-in compatible with the previous behavior because 'x.cast<T>()' is allowed when x is already known to be a T. The one exception to this is that we need declarations of the types used by ops, which is why a couple headers needed additional #includes. I updated a few things in tree to remove the now-redundant `.cast<>`'s, but there are probably many more than can be removed. Differential Revision: https://reviews.llvm.org/D93790
This commit is contained in:
parent
8791949f55
commit
9eb3e564d3
|
@ -210,7 +210,9 @@ class ConstantOp : public mlir::Op<ConstantOp,
|
|||
/// The ConstantOp takes no inputs.
|
||||
mlir::OpTrait::ZeroOperands,
|
||||
/// The ConstantOp returns a single result.
|
||||
mlir::OpTrait::OneResult> {
|
||||
mlir::OpTrait::OneResult,
|
||||
/// The result of getType is `Type`.
|
||||
mlir::OpTraits::OneTypedResult<Type>::Impl> {
|
||||
|
||||
public:
|
||||
/// Inherit the constructors from the base Op class.
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#ifndef STANDALONE_STANDALONEOPS_H
|
||||
#define STANDALONE_STANDALONEOPS_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
|
||||
#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
|
||||
#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -175,8 +175,12 @@ class Constraint<Pred pred, string desc = ""> {
|
|||
// are considered as uncategorized constraints.
|
||||
|
||||
// Subclass for constraints on a type.
|
||||
class TypeConstraint<Pred predicate, string description = ""> :
|
||||
Constraint<predicate, description>;
|
||||
class TypeConstraint<Pred predicate, string description = "",
|
||||
string cppClassNameParam = "::mlir::Type"> :
|
||||
Constraint<predicate, description> {
|
||||
// The name of the C++ Type class if known, or Type if not.
|
||||
string cppClassName = cppClassNameParam;
|
||||
}
|
||||
|
||||
// Subclass for constraints on an attribute.
|
||||
class AttrConstraint<Pred predicate, string description = ""> :
|
||||
|
@ -285,8 +289,9 @@ class Dialect {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// A type, carries type constraints.
|
||||
class Type<Pred condition, string descr = ""> :
|
||||
TypeConstraint<condition, descr> {
|
||||
class Type<Pred condition, string descr = "",
|
||||
string cppClassName = "::mlir::Type"> :
|
||||
TypeConstraint<condition, descr, cppClassName> {
|
||||
string typeDescription = "";
|
||||
string builderCall = "";
|
||||
}
|
||||
|
@ -299,8 +304,9 @@ class TypeAlias<Type t, string description = t.description> :
|
|||
}
|
||||
|
||||
// A type of a specific dialect.
|
||||
class DialectType<Dialect d, Pred condition, string descr = ""> :
|
||||
Type<condition, descr> {
|
||||
class DialectType<Dialect d, Pred condition, string descr = "",
|
||||
string cppClassName = "::mlir::Type"> :
|
||||
Type<condition, descr, cppClassName> {
|
||||
Dialect dialect = d;
|
||||
}
|
||||
|
||||
|
@ -331,11 +337,13 @@ class BuildableType<code builder> {
|
|||
def AnyType : Type<CPred<"true">, "any type">;
|
||||
|
||||
// None type
|
||||
def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type">,
|
||||
def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
|
||||
"::mlir::NoneType">,
|
||||
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
|
||||
|
||||
// Any type from the given list
|
||||
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
|
||||
class AnyTypeOf<list<Type> allowedTypes, string description = "",
|
||||
string cppClassName = "::mlir::Type"> : Type<
|
||||
// Satisfy any of the allowed type's condition
|
||||
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
|
||||
!if(!eq(description, ""),
|
||||
|
@ -345,7 +353,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
|
|||
// Integer types.
|
||||
|
||||
// Any integer type irrespective of its width and signedness semantics.
|
||||
def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer">;
|
||||
def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer",
|
||||
"::mlir::IntegerType">;
|
||||
|
||||
// Any integer type (regardless of signedness semantics) of a specific width.
|
||||
class AnyI<int width>
|
||||
|
@ -355,7 +364,8 @@ class AnyI<int width>
|
|||
|
||||
class AnyIntOfWidths<list<int> widths> :
|
||||
AnyTypeOf<!foreach(w, widths, AnyI<w>),
|
||||
StrJoinInt<widths, "/">.result # "-bit integer">;
|
||||
StrJoinInt<widths, "/">.result # "-bit integer",
|
||||
"::mlir::IntegerType">;
|
||||
|
||||
def AnyI1 : AnyI<1>;
|
||||
def AnyI8 : AnyI<8>;
|
||||
|
@ -365,12 +375,13 @@ def AnyI64 : AnyI<64>;
|
|||
|
||||
// Any signless integer type irrespective of its width.
|
||||
def AnySignlessInteger : Type<
|
||||
CPred<"$_self.isSignlessInteger()">, "signless integer">;
|
||||
CPred<"$_self.isSignlessInteger()">, "signless integer",
|
||||
"::mlir::IntegerType">;
|
||||
|
||||
// Signless integer type of a specific width.
|
||||
class I<int width>
|
||||
: Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
|
||||
width # "-bit signless integer">,
|
||||
width # "-bit signless integer", "::mlir::IntegerType">,
|
||||
BuildableType<"$_builder.getIntegerType(" # width # ")"> {
|
||||
int bitwidth = width;
|
||||
}
|
||||
|
@ -392,7 +403,7 @@ def AnySignedInteger : Type<
|
|||
// Signed integer type of a specific width.
|
||||
class SI<int width>
|
||||
: Type<CPred<"$_self.isSignedInteger(" # width # ")">,
|
||||
width # "-bit signed integer">,
|
||||
width # "-bit signed integer", "::mlir::IntegerType">,
|
||||
BuildableType<
|
||||
"$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
|
||||
int bitwidth = width;
|
||||
|
@ -415,7 +426,7 @@ def AnyUnsignedInteger : Type<
|
|||
// Unsigned integer type of a specific width.
|
||||
class UI<int width>
|
||||
: Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
|
||||
width # "-bit unsigned integer">,
|
||||
width # "-bit unsigned integer", "::mlir::IntegerType">,
|
||||
BuildableType<
|
||||
"$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
|
||||
int bitwidth = width;
|
||||
|
@ -432,18 +443,20 @@ def UI32 : UI<32>;
|
|||
def UI64 : UI<64>;
|
||||
|
||||
// Index type.
|
||||
def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index">,
|
||||
def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
|
||||
"::mlir::IndexType">,
|
||||
BuildableType<"$_builder.getIndexType()">;
|
||||
|
||||
// Floating point types.
|
||||
|
||||
// Any float type irrespective of its width.
|
||||
def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point">;
|
||||
def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point",
|
||||
"::mlir::FloatType">;
|
||||
|
||||
// Float type of a specific width.
|
||||
class F<int width>
|
||||
: Type<CPred<"$_self.isF" # width # "()">,
|
||||
width # "-bit float">,
|
||||
width # "-bit float", "::mlir::FloatType">,
|
||||
BuildableType<"$_builder.getF" # width # "Type()"> {
|
||||
int bitwidth = width;
|
||||
}
|
||||
|
@ -465,16 +478,17 @@ class Complex<Type type>
|
|||
SubstLeaves<"$_self",
|
||||
"$_self.cast<::mlir::ComplexType>().getElementType()",
|
||||
type.predicate>]>,
|
||||
"complex type with " # type.description # " elements"> {
|
||||
"complex type with " # type.description # " elements",
|
||||
"::mlir::ComplexType"> {
|
||||
Type elementType = type;
|
||||
}
|
||||
|
||||
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
|
||||
"complex-type">;
|
||||
"complex-type", "::mlir::ComplexType">;
|
||||
|
||||
class OpaqueType<string dialect, string name, string description>
|
||||
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
|
||||
description>,
|
||||
description, "::mlir::OpaqueType">,
|
||||
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
|
||||
"$_builder.getIdentifier(\"" # dialect # "\"), \""
|
||||
# name # "\")">;
|
||||
|
@ -483,17 +497,17 @@ class OpaqueType<string dialect, string name, string description>
|
|||
|
||||
// Any function type.
|
||||
def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">,
|
||||
"function type">;
|
||||
"function type", "::mlir::FunctionType">;
|
||||
|
||||
// A container type is a type that has another type embedded within it.
|
||||
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
||||
string descr> :
|
||||
string descr, string cppClassName = "::mlir::Type"> :
|
||||
// First, check the container predicate. Then, substitute the extracted
|
||||
// element into the element type checker.
|
||||
Type<And<[containerPred,
|
||||
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
|
||||
etype.predicate>]>,
|
||||
descr # " of " # etype.description # " values"> {
|
||||
descr # " of " # etype.description # " values", cppClassName> {
|
||||
// The type of elements in the container.
|
||||
Type elementType = etype;
|
||||
|
||||
|
@ -502,9 +516,11 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
|||
}
|
||||
|
||||
class ShapedContainerType<list<Type> allowedTypes,
|
||||
Pred containerPred, string descr> :
|
||||
Pred containerPred, string descr,
|
||||
string cppClassName = "::mlir::Type"> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
|
||||
"$_self.cast<::mlir::ShapedType>().getElementType()", descr>;
|
||||
"$_self.cast<::mlir::ShapedType>().getElementType()", descr,
|
||||
cppClassName>;
|
||||
|
||||
// Whether a shaped type is ranked.
|
||||
def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
|
||||
|
@ -520,7 +536,8 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
|
|||
// Vector types.
|
||||
|
||||
class VectorOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
|
||||
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Whether the number of elements of a vector is from the given
|
||||
// `allowedRanks` list
|
||||
|
@ -534,7 +551,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
|
|||
// Any vector where the rank is from the given `allowedRanks` list
|
||||
class VectorOfRank<list<int> allowedRanks> : Type<
|
||||
IsVectorOfRankPred<allowedRanks>,
|
||||
" of ranks " # StrJoinInt<allowedRanks, "/">.result>;
|
||||
" of ranks " # StrJoinInt<allowedRanks, "/">.result, "::mlir::VectorType">;
|
||||
|
||||
// Any vector where the rank is from the given `allowedRanks` list and the type
|
||||
// is from the given `allowedTypes` list
|
||||
|
@ -543,7 +560,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
|
|||
And<[VectorOf<allowedTypes>.predicate,
|
||||
VectorOfRank<allowedRanks>.predicate]>,
|
||||
VectorOf<allowedTypes>.description #
|
||||
VectorOfRank<allowedRanks>.description>;
|
||||
VectorOfRank<allowedRanks>.description,
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Whether the number of elements of a vector is from the given
|
||||
// `allowedLengths` list
|
||||
|
@ -558,7 +576,8 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
|
|||
// `allowedLengths` list
|
||||
class VectorOfLength<list<int> allowedLengths> : Type<
|
||||
IsVectorOfLengthPred<allowedLengths>,
|
||||
" of length " # StrJoinInt<allowedLengths, "/">.result>;
|
||||
" of length " # StrJoinInt<allowedLengths, "/">.result,
|
||||
"::mlir::VectorType">;
|
||||
|
||||
|
||||
// Any vector where the number of elements is from the given
|
||||
|
@ -569,30 +588,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
|
|||
And<[VectorOf<allowedTypes>.predicate,
|
||||
VectorOfLength<allowedLengths>.predicate]>,
|
||||
VectorOf<allowedTypes>.description #
|
||||
VectorOfLength<allowedLengths>.description>;
|
||||
VectorOfLength<allowedLengths>.description,
|
||||
"::mlir::VectorType">;
|
||||
|
||||
def AnyVector : VectorOf<[AnyType]>;
|
||||
|
||||
// Shaped types.
|
||||
|
||||
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
|
||||
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
|
||||
"::mlir::ShapedType">;
|
||||
|
||||
// Tensor types.
|
||||
|
||||
// Any tensor type whose element type is from the given `allowedTypes` list
|
||||
class TensorOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
|
||||
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
|
||||
"::mlir::TensorType">;
|
||||
|
||||
def AnyTensor : TensorOf<[AnyType]>;
|
||||
|
||||
def AnyRankedTensor :
|
||||
ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>,
|
||||
"ranked tensor">;
|
||||
"ranked tensor", "::mlir::TensorType">;
|
||||
|
||||
// TODO: Have an easy way to add another constraint to a type.
|
||||
class StaticShapeTensorOf<list<Type> allowedTypes>
|
||||
: Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
|
||||
"statically shaped " # TensorOf<allowedTypes>.description>;
|
||||
"statically shaped " # TensorOf<allowedTypes>.description,
|
||||
"::mlir::TensorType">;
|
||||
|
||||
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
|
||||
|
||||
|
@ -612,7 +635,7 @@ def F64Tensor : TensorOf<[F64]>;
|
|||
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
|
||||
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
|
||||
StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
|
||||
TensorOf<allowedTypes>.description>;
|
||||
TensorOf<allowedTypes>.description, "::mlir::TensorType">;
|
||||
|
||||
class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
|
||||
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
|
||||
|
@ -623,12 +646,14 @@ class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
|
|||
// Unranked Memref type
|
||||
def AnyUnrankedMemRef :
|
||||
ShapedContainerType<[AnyType],
|
||||
IsUnrankedMemRefTypePred, "unranked.memref">;
|
||||
IsUnrankedMemRefTypePred, "unranked.memref",
|
||||
"::mlir::MemRefType">;
|
||||
// Memref type.
|
||||
|
||||
// Memrefs are blocks of data with fixed type and rank.
|
||||
class MemRefOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">;
|
||||
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
|
||||
"::mlir::MemRefType">;
|
||||
|
||||
def AnyMemRef : MemRefOf<[AnyType]>;
|
||||
|
||||
|
@ -679,7 +704,7 @@ class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
|
|||
MemRefOf<allowedTypes>.description>;
|
||||
|
||||
// This represents a generic tuple without any constraints on element type.
|
||||
def AnyTuple : Type<IsTupleTypePred, "tuple">;
|
||||
def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
|
||||
|
||||
// A container type that has other types embedded in it, but (unlike
|
||||
// ContainerType) can hold elements with a mix of types. Requires a call that
|
||||
|
@ -2414,9 +2439,7 @@ def replaceWithValue;
|
|||
// the given C++ base class.
|
||||
class TypeDef<Dialect dialect, string name,
|
||||
string baseCppClass = "::mlir::Type">
|
||||
: DialectType<dialect, CPred<"">> {
|
||||
// The name of the C++ Type class.
|
||||
string cppClassName = name # "Type";
|
||||
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
|
||||
// The name of the C++ base class to use for this Type.
|
||||
string cppBaseClassName = baseCppClass;
|
||||
|
||||
|
|
|
@ -28,10 +28,6 @@ namespace mlir {
|
|||
class Builder;
|
||||
class OpBuilder;
|
||||
|
||||
namespace OpTrait {
|
||||
template <typename ConcreteType> class OneResult;
|
||||
}
|
||||
|
||||
/// This class represents success/failure for operation parsing. It is
|
||||
/// essentially a simple wrapper class around LogicalResult that allows for
|
||||
/// explicit conversion to bool. This allows for the parser to chain together
|
||||
|
@ -188,7 +184,8 @@ public:
|
|||
void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
|
||||
|
||||
/// Set the dialect attributes for this operation, and preserve all dependent.
|
||||
template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
|
||||
template <typename DialectAttrs>
|
||||
void setDialectAttrs(DialectAttrs &&attrs) {
|
||||
state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
|
||||
}
|
||||
|
||||
|
@ -424,7 +421,8 @@ public:
|
|||
///
|
||||
/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class NOperands {
|
||||
template <unsigned N>
|
||||
class NOperands {
|
||||
public:
|
||||
static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
|
||||
|
||||
|
@ -443,7 +441,8 @@ public:
|
|||
///
|
||||
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class AtLeastNOperands {
|
||||
template <unsigned N>
|
||||
class AtLeastNOperands {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public detail::MultiOperandTraitBase<ConcreteType,
|
||||
|
@ -517,7 +516,8 @@ public:
|
|||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of regions.
|
||||
template <unsigned N> class NRegions {
|
||||
template <unsigned N>
|
||||
class NRegions {
|
||||
public:
|
||||
static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
|
||||
|
||||
|
@ -533,7 +533,8 @@ public:
|
|||
|
||||
/// This class provides APIs for ops that are known to have at least a specified
|
||||
/// number of regions.
|
||||
template <unsigned N> class AtLeastNRegions {
|
||||
template <unsigned N>
|
||||
class AtLeastNRegions {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public detail::MultiRegionTraitBase<ConcreteType,
|
||||
|
@ -582,7 +583,8 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
|||
|
||||
/// Replace all uses of results of this operation with the provided 'values'.
|
||||
/// 'values' may correspond to an existing operation, or a range of 'Value'.
|
||||
template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
|
||||
template <typename ValuesT>
|
||||
void replaceAllUsesWith(ValuesT &&values) {
|
||||
this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
|
||||
}
|
||||
|
||||
|
@ -610,20 +612,19 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
|||
} // end namespace detail
|
||||
|
||||
/// This class provides return value APIs for ops that are known to have a
|
||||
/// single result.
|
||||
/// single result. ResultType is the concrete type returned by getType().
|
||||
template <typename ConcreteType>
|
||||
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||
public:
|
||||
Value getResult() { return this->getOperation()->getResult(0); }
|
||||
Type getType() { return getResult().getType(); }
|
||||
|
||||
/// If the operation returns a single value, then the Op can be implicitly
|
||||
/// converted to an Value. This yields the value of the only result.
|
||||
operator Value() { return getResult(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
/// Replace all uses of 'this' value with the new value, updating anything
|
||||
/// in the IR that uses 'this' to use the other value instead. When this
|
||||
/// returns there are zero uses of 'this'.
|
||||
void replaceAllUsesWith(Value newValue) {
|
||||
getResult().replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
@ -638,12 +639,33 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This trait is used for return value APIs for ops that are known to have a
|
||||
/// specific type other than `Type`. This allows the "getType()" member to be
|
||||
/// more specific for an op. This should be used in conjunction with OneResult,
|
||||
/// and occur in the trait list before OneResult.
|
||||
template <typename ResultType>
|
||||
class OneTypedResult {
|
||||
public:
|
||||
/// This class provides return value APIs for ops that are known to have a
|
||||
/// single result. ResultType is the concrete type returned by getType().
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
|
||||
public:
|
||||
ResultType getType() {
|
||||
auto resultTy = this->getOperation()->getResult(0).getType();
|
||||
return resultTy.template cast<ResultType>();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of results. This is used as a trait like this:
|
||||
///
|
||||
/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class NResults {
|
||||
template <unsigned N>
|
||||
class NResults {
|
||||
public:
|
||||
static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
|
||||
|
||||
|
@ -662,7 +684,8 @@ public:
|
|||
///
|
||||
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class AtLeastNResults {
|
||||
template <unsigned N>
|
||||
class AtLeastNResults {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public detail::MultiResultTraitBase<ConcreteType,
|
||||
|
@ -1573,7 +1596,8 @@ private:
|
|||
using has_fold = decltype(
|
||||
std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
|
||||
std::declval<SmallVectorImpl<OpFoldResult> &>()));
|
||||
template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
|
||||
template <typename T>
|
||||
using detect_has_fold = llvm::is_detected<has_fold, T>;
|
||||
/// Trait to check if T provides a 'print' method.
|
||||
template <typename T, typename... Args>
|
||||
using has_print =
|
||||
|
|
|
@ -47,6 +47,9 @@ public:
|
|||
// Returns the builder call for this constraint if this is a buildable type,
|
||||
// returns None otherwise.
|
||||
Optional<StringRef> getBuilderCall() const;
|
||||
|
||||
// Return the C++ class name for this type (which may just be ::mlir::Type).
|
||||
StringRef getCPPClassName() const;
|
||||
};
|
||||
|
||||
// Wrapper class with helper methods for accessing Types defined in TableGen.
|
||||
|
|
|
@ -612,7 +612,7 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = op->getResult(0).getType().cast<VectorType>();
|
||||
auto dstType = op.getType();
|
||||
int64_t rank = dstType.getRank();
|
||||
if (rank == 1) {
|
||||
rewriter.replaceOp(
|
||||
|
@ -1091,8 +1091,7 @@ public:
|
|||
auto loc = castOp->getLoc();
|
||||
MemRefType sourceMemRefType =
|
||||
castOp.getOperand().getType().cast<MemRefType>();
|
||||
MemRefType targetMemRefType =
|
||||
castOp.getResult().getType().cast<MemRefType>();
|
||||
MemRefType targetMemRefType = castOp.getType();
|
||||
|
||||
// Only static shape casts supported atm.
|
||||
if (!sourceMemRefType.hasStaticShape() ||
|
||||
|
@ -1459,7 +1458,7 @@ public:
|
|||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getResult().getType().cast<VectorType>();
|
||||
auto dstType = op.getType();
|
||||
|
||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ static MaskFormat get1DMaskFormat(Value mask) {
|
|||
ArrayAttr masks = m.mask_dim_sizes();
|
||||
assert(masks.size() == 1);
|
||||
int64_t i = masks[0].cast<IntegerAttr>().getInt();
|
||||
int64_t u = m.getType().cast<VectorType>().getDimSize(0);
|
||||
int64_t u = m.getType().getDimSize(0);
|
||||
if (i >= u)
|
||||
return MaskFormat::AllTrue;
|
||||
if (i <= 0)
|
||||
|
@ -849,7 +849,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
|||
return Value();
|
||||
// Get the nth dimension size starting from lowest dimension.
|
||||
auto getDimReverse = [](VectorType type, int64_t n) {
|
||||
return type.getShape().take_back(n+1).front();
|
||||
return type.getShape().take_back(n + 1).front();
|
||||
};
|
||||
int64_t destinationRank =
|
||||
extractOp.getType().isa<VectorType>()
|
||||
|
@ -1870,9 +1870,8 @@ public:
|
|||
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
|
||||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr = DenseElementsAttr::get(
|
||||
extractStridedSliceOp.getType().cast<VectorType>(),
|
||||
dense.getSplatValue());
|
||||
auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
|
||||
dense.getSplatValue());
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -999,8 +999,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
|
|||
return failure();
|
||||
auto operandSourceVectorType =
|
||||
sourceShapeCastOp.source().getType().cast<VectorType>();
|
||||
auto operandResultVectorType =
|
||||
sourceShapeCastOp.result().getType().cast<VectorType>();
|
||||
auto operandResultVectorType = sourceShapeCastOp.getType();
|
||||
|
||||
// Check if shape cast operations invert each other.
|
||||
if (operandSourceVectorType != resultVectorType ||
|
||||
|
@ -1397,7 +1396,7 @@ public:
|
|||
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto dstType = op.getResult().getType().cast<VectorType>();
|
||||
auto dstType = op.getType();
|
||||
auto eltType = dstType.getElementType();
|
||||
auto dimSizes = op.mask_dim_sizes();
|
||||
int64_t rank = dimSizes.size();
|
||||
|
|
|
@ -53,6 +53,11 @@ Optional<StringRef> TypeConstraint::getBuilderCall() const {
|
|||
.Default([](auto *) { return llvm::None; });
|
||||
}
|
||||
|
||||
// Return the C++ class name for this type (which may just be ::mlir::Type).
|
||||
StringRef TypeConstraint::getCPPClassName() const {
|
||||
return def->getValueAsString("cppClassName");
|
||||
}
|
||||
|
||||
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
|
||||
|
||||
StringRef Type::getTypeDescription() const {
|
||||
|
|
|
@ -2137,11 +2137,18 @@ void OpEmitter::genTraits() {
|
|||
unsigned numVariadicRegions = op.getNumVariadicRegions();
|
||||
addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
|
||||
|
||||
// Add result size trait.
|
||||
// Add result size traits.
|
||||
int numResults = op.getNumResults();
|
||||
int numVariadicResults = op.getNumVariableLengthResults();
|
||||
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
|
||||
|
||||
// For single result ops with a known specific type, generate a OneTypedResult
|
||||
// trait.
|
||||
if (numResults == 1 && numVariadicResults == 0) {
|
||||
auto cppName = op.getResults().begin()->constraint.getCPPClassName();
|
||||
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
|
||||
}
|
||||
|
||||
// Add successor size trait.
|
||||
unsigned numSuccessors = op.getNumSuccessors();
|
||||
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
|
||||
|
|
Loading…
Reference in New Issue