[ODS] Make the getType() method on a OneResult instruction return a specific type.

Implement Bug 46698, making ODS synthesize a getType() method that returns a
specific C++ class for OneResult methods where we know that class.  This eliminates
a common source of casts in things like:

   myOp.getType().cast<FIRRTLType>().getPassive()

because we know that myOp always returns a FIRRTLType.  This also encourages
op authors to type their results more tightly (which is also good for
verification).

I chose to implement this by splitting the OneResult trait into itself plus a
OneTypedResult trait, given that many things are using `hasTrait<OneResult>`
to conditionalize various logic.

While this changes makes many many ops get more specific getType() results, it
is generally drop-in compatible with the previous behavior because 'x.cast<T>()'
is allowed when x is already known to be a T.  The one exception to this is that
we need declarations of the types used by ops, which is why a couple headers
needed additional #includes.

I updated a few things in tree to remove the now-redundant `.cast<>`'s, but there
are probably many more than can be removed.

Differential Revision: https://reviews.llvm.org/D93790
This commit is contained in:
Chris Lattner 2020-12-23 18:13:39 -08:00
parent 8791949f55
commit 9eb3e564d3
12 changed files with 137 additions and 73 deletions

View File

@ -210,7 +210,9 @@ class ConstantOp : public mlir::Op<ConstantOp,
/// The ConstantOp takes no inputs.
mlir::OpTrait::ZeroOperands,
/// The ConstantOp returns a single result.
mlir::OpTrait::OneResult> {
mlir::OpTrait::OneResult,
/// The result of getType is `Type`.
mlir::OpTraits::OneTypedResult<Type>::Impl> {
public:
/// Inherit the constructors from the base Op class.

View File

@ -9,6 +9,7 @@
#ifndef STANDALONE_STANDALONEOPS_H
#define STANDALONE_STANDALONEOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -175,8 +175,12 @@ class Constraint<Pred pred, string desc = ""> {
// are considered as uncategorized constraints.
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
class TypeConstraint<Pred predicate, string description = "",
string cppClassNameParam = "::mlir::Type"> :
Constraint<predicate, description> {
// The name of the C++ Type class if known, or Type if not.
string cppClassName = cppClassNameParam;
}
// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string description = ""> :
@ -285,8 +289,9 @@ class Dialect {
//===----------------------------------------------------------------------===//
// A type, carries type constraints.
class Type<Pred condition, string descr = ""> :
TypeConstraint<condition, descr> {
class Type<Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
TypeConstraint<condition, descr, cppClassName> {
string typeDescription = "";
string builderCall = "";
}
@ -299,8 +304,9 @@ class TypeAlias<Type t, string description = t.description> :
}
// A type of a specific dialect.
class DialectType<Dialect d, Pred condition, string descr = ""> :
Type<condition, descr> {
class DialectType<Dialect d, Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
Type<condition, descr, cppClassName> {
Dialect dialect = d;
}
@ -331,11 +337,13 @@ class BuildableType<code builder> {
def AnyType : Type<CPred<"true">, "any type">;
// None type
def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type">,
def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
"::mlir::NoneType">,
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
class AnyTypeOf<list<Type> allowedTypes, string description = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed type's condition
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
!if(!eq(description, ""),
@ -345,7 +353,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer">;
def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer",
"::mlir::IntegerType">;
// Any integer type (regardless of signedness semantics) of a specific width.
class AnyI<int width>
@ -355,7 +364,8 @@ class AnyI<int width>
class AnyIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, AnyI<w>),
StrJoinInt<widths, "/">.result # "-bit integer">;
StrJoinInt<widths, "/">.result # "-bit integer",
"::mlir::IntegerType">;
def AnyI1 : AnyI<1>;
def AnyI8 : AnyI<8>;
@ -365,12 +375,13 @@ def AnyI64 : AnyI<64>;
// Any signless integer type irrespective of its width.
def AnySignlessInteger : Type<
CPred<"$_self.isSignlessInteger()">, "signless integer">;
CPred<"$_self.isSignlessInteger()">, "signless integer",
"::mlir::IntegerType">;
// Signless integer type of a specific width.
class I<int width>
: Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
width # "-bit signless integer">,
width # "-bit signless integer", "::mlir::IntegerType">,
BuildableType<"$_builder.getIntegerType(" # width # ")"> {
int bitwidth = width;
}
@ -392,7 +403,7 @@ def AnySignedInteger : Type<
// Signed integer type of a specific width.
class SI<int width>
: Type<CPred<"$_self.isSignedInteger(" # width # ")">,
width # "-bit signed integer">,
width # "-bit signed integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
int bitwidth = width;
@ -415,7 +426,7 @@ def AnyUnsignedInteger : Type<
// Unsigned integer type of a specific width.
class UI<int width>
: Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
width # "-bit unsigned integer">,
width # "-bit unsigned integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
int bitwidth = width;
@ -432,18 +443,20 @@ def UI32 : UI<32>;
def UI64 : UI<64>;
// Index type.
def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index">,
def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
"::mlir::IndexType">,
BuildableType<"$_builder.getIndexType()">;
// Floating point types.
// Any float type irrespective of its width.
def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point">;
def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point",
"::mlir::FloatType">;
// Float type of a specific width.
class F<int width>
: Type<CPred<"$_self.isF" # width # "()">,
width # "-bit float">,
width # "-bit float", "::mlir::FloatType">,
BuildableType<"$_builder.getF" # width # "Type()"> {
int bitwidth = width;
}
@ -465,16 +478,17 @@ class Complex<Type type>
SubstLeaves<"$_self",
"$_self.cast<::mlir::ComplexType>().getElementType()",
type.predicate>]>,
"complex type with " # type.description # " elements"> {
"complex type with " # type.description # " elements",
"::mlir::ComplexType"> {
Type elementType = type;
}
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
"complex-type">;
"complex-type", "::mlir::ComplexType">;
class OpaqueType<string dialect, string name, string description>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
description>,
description, "::mlir::OpaqueType">,
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
"$_builder.getIdentifier(\"" # dialect # "\"), \""
# name # "\")">;
@ -483,17 +497,17 @@ class OpaqueType<string dialect, string name, string description>
// Any function type.
def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">,
"function type">;
"function type", "::mlir::FunctionType">;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr> :
string descr, string cppClassName = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # " of " # etype.description # " values"> {
descr # " of " # etype.description # " values", cppClassName> {
// The type of elements in the container.
Type elementType = etype;
@ -502,9 +516,11 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
}
class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr> :
Pred containerPred, string descr,
string cppClassName = "::mlir::Type"> :
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
"$_self.cast<::mlir::ShapedType>().getElementType()", descr>;
"$_self.cast<::mlir::ShapedType>().getElementType()", descr,
cppClassName>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
@ -520,7 +536,8 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
// Vector types.
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
@ -534,7 +551,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
// Any vector where the rank is from the given `allowedRanks` list
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
" of ranks " # StrJoinInt<allowedRanks, "/">.result>;
" of ranks " # StrJoinInt<allowedRanks, "/">.result, "::mlir::VectorType">;
// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
@ -543,7 +560,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
And<[VectorOf<allowedTypes>.predicate,
VectorOfRank<allowedRanks>.predicate]>,
VectorOf<allowedTypes>.description #
VectorOfRank<allowedRanks>.description>;
VectorOfRank<allowedRanks>.description,
"::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedLengths` list
@ -558,7 +576,8 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
IsVectorOfLengthPred<allowedLengths>,
" of length " # StrJoinInt<allowedLengths, "/">.result>;
" of length " # StrJoinInt<allowedLengths, "/">.result,
"::mlir::VectorType">;
// Any vector where the number of elements is from the given
@ -569,30 +588,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
And<[VectorOf<allowedTypes>.predicate,
VectorOfLength<allowedLengths>.predicate]>,
VectorOf<allowedTypes>.description #
VectorOfLength<allowedLengths>.description>;
VectorOfLength<allowedLengths>.description,
"::mlir::VectorType">;
def AnyVector : VectorOf<[AnyType]>;
// Shaped types.
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list
class TensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
"::mlir::TensorType">;
def AnyTensor : TensorOf<[AnyType]>;
def AnyRankedTensor :
ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>,
"ranked tensor">;
"ranked tensor", "::mlir::TensorType">;
// TODO: Have an easy way to add another constraint to a type.
class StaticShapeTensorOf<list<Type> allowedTypes>
: Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
"statically shaped " # TensorOf<allowedTypes>.description>;
"statically shaped " # TensorOf<allowedTypes>.description,
"::mlir::TensorType">;
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
@ -612,7 +635,7 @@ def F64Tensor : TensorOf<[F64]>;
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
TensorOf<allowedTypes>.description>;
TensorOf<allowedTypes>.description, "::mlir::TensorType">;
class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
@ -623,12 +646,14 @@ class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
// Unranked Memref type
def AnyUnrankedMemRef :
ShapedContainerType<[AnyType],
IsUnrankedMemRefTypePred, "unranked.memref">;
IsUnrankedMemRefTypePred, "unranked.memref",
"::mlir::MemRefType">;
// Memref type.
// Memrefs are blocks of data with fixed type and rank.
class MemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref">;
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
"::mlir::MemRefType">;
def AnyMemRef : MemRefOf<[AnyType]>;
@ -679,7 +704,7 @@ class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
MemRefOf<allowedTypes>.description>;
// This represents a generic tuple without any constraints on element type.
def AnyTuple : Type<IsTupleTypePred, "tuple">;
def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
// A container type that has other types embedded in it, but (unlike
// ContainerType) can hold elements with a mix of types. Requires a call that
@ -2414,9 +2439,7 @@ def replaceWithValue;
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">> {
// The name of the C++ Type class.
string cppClassName = name # "Type";
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type"> {
// The name of the C++ base class to use for this Type.
string cppBaseClassName = baseCppClass;

View File

@ -28,10 +28,6 @@ namespace mlir {
class Builder;
class OpBuilder;
namespace OpTrait {
template <typename ConcreteType> class OneResult;
}
/// This class represents success/failure for operation parsing. It is
/// essentially a simple wrapper class around LogicalResult that allows for
/// explicit conversion to bool. This allows for the parser to chain together
@ -188,7 +184,8 @@ public:
void setAttrs(DictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
/// Set the dialect attributes for this operation, and preserve all dependent.
template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
template <typename DialectAttrs>
void setDialectAttrs(DialectAttrs &&attrs) {
state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
}
@ -424,7 +421,8 @@ public:
///
/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N> class NOperands {
template <unsigned N>
class NOperands {
public:
static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
@ -443,7 +441,8 @@ public:
///
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
template <unsigned N> class AtLeastNOperands {
template <unsigned N>
class AtLeastNOperands {
public:
template <typename ConcreteType>
class Impl : public detail::MultiOperandTraitBase<ConcreteType,
@ -517,7 +516,8 @@ public:
/// This class provides the API for ops that are known to have a specified
/// number of regions.
template <unsigned N> class NRegions {
template <unsigned N>
class NRegions {
public:
static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
@ -533,7 +533,8 @@ public:
/// This class provides APIs for ops that are known to have at least a specified
/// number of regions.
template <unsigned N> class AtLeastNRegions {
template <unsigned N>
class AtLeastNRegions {
public:
template <typename ConcreteType>
class Impl : public detail::MultiRegionTraitBase<ConcreteType,
@ -582,7 +583,8 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
/// Replace all uses of results of this operation with the provided 'values'.
/// 'values' may correspond to an existing operation, or a range of 'Value'.
template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
template <typename ValuesT>
void replaceAllUsesWith(ValuesT &&values) {
this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
}
@ -610,20 +612,19 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
} // end namespace detail
/// This class provides return value APIs for ops that are known to have a
/// single result.
/// single result. ResultType is the concrete type returned by getType().
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
Value getResult() { return this->getOperation()->getResult(0); }
Type getType() { return getResult().getType(); }
/// If the operation returns a single value, then the Op can be implicitly
/// converted to an Value. This yields the value of the only result.
operator Value() { return getResult(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
/// Replace all uses of 'this' value with the new value, updating anything
/// in the IR that uses 'this' to use the other value instead. When this
/// returns there are zero uses of 'this'.
void replaceAllUsesWith(Value newValue) {
getResult().replaceAllUsesWith(newValue);
}
@ -638,12 +639,33 @@ public:
}
};
/// This trait is used for return value APIs for ops that are known to have a
/// specific type other than `Type`. This allows the "getType()" member to be
/// more specific for an op. This should be used in conjunction with OneResult,
/// and occur in the trait list before OneResult.
template <typename ResultType>
class OneTypedResult {
public:
/// This class provides return value APIs for ops that are known to have a
/// single result. ResultType is the concrete type returned by getType().
template <typename ConcreteType>
class Impl
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
public:
ResultType getType() {
auto resultTy = this->getOperation()->getResult(0).getType();
return resultTy.template cast<ResultType>();
}
};
};
/// This class provides the API for ops that are known to have a specified
/// number of results. This is used as a trait like this:
///
/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N> class NResults {
template <unsigned N>
class NResults {
public:
static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
@ -662,7 +684,8 @@ public:
///
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
template <unsigned N> class AtLeastNResults {
template <unsigned N>
class AtLeastNResults {
public:
template <typename ConcreteType>
class Impl : public detail::MultiResultTraitBase<ConcreteType,
@ -1573,7 +1596,8 @@ private:
using has_fold = decltype(
std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
template <typename T>
using detect_has_fold = llvm::is_detected<has_fold, T>;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
using has_print =

View File

@ -47,6 +47,9 @@ public:
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;
// Return the C++ class name for this type (which may just be ::mlir::Type).
StringRef getCPPClassName() const;
};
// Wrapper class with helper methods for accessing Types defined in TableGen.

View File

@ -612,7 +612,7 @@ public:
LogicalResult
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op->getResult(0).getType().cast<VectorType>();
auto dstType = op.getType();
int64_t rank = dstType.getRank();
if (rank == 1) {
rewriter.replaceOp(
@ -1091,8 +1091,7 @@ public:
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
castOp.getOperand().getType().cast<MemRefType>();
MemRefType targetMemRefType =
castOp.getResult().getType().cast<MemRefType>();
MemRefType targetMemRefType = castOp.getType();
// Only static shape casts supported atm.
if (!sourceMemRefType.hasStaticShape() ||
@ -1459,7 +1458,7 @@ public:
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
auto dstType = op.getType();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

View File

@ -67,7 +67,7 @@ static MaskFormat get1DMaskFormat(Value mask) {
ArrayAttr masks = m.mask_dim_sizes();
assert(masks.size() == 1);
int64_t i = masks[0].cast<IntegerAttr>().getInt();
int64_t u = m.getType().cast<VectorType>().getDimSize(0);
int64_t u = m.getType().getDimSize(0);
if (i >= u)
return MaskFormat::AllTrue;
if (i <= 0)
@ -849,7 +849,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getShape().take_back(n+1).front();
return type.getShape().take_back(n + 1).front();
};
int64_t destinationRank =
extractOp.getType().isa<VectorType>()
@ -1870,9 +1870,8 @@ public:
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
if (!dense)
return failure();
auto newAttr = DenseElementsAttr::get(
extractStridedSliceOp.getType().cast<VectorType>(),
dense.getSplatValue());
auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
dense.getSplatValue());
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
return success();
}

View File

@ -999,8 +999,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
return failure();
auto operandSourceVectorType =
sourceShapeCastOp.source().getType().cast<VectorType>();
auto operandResultVectorType =
sourceShapeCastOp.result().getType().cast<VectorType>();
auto operandResultVectorType = sourceShapeCastOp.getType();
// Check if shape cast operations invert each other.
if (operandSourceVectorType != resultVectorType ||
@ -1397,7 +1396,7 @@ public:
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getResult().getType().cast<VectorType>();
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size();

View File

@ -53,6 +53,11 @@ Optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return llvm::None; });
}
// Return the C++ class name for this type (which may just be ::mlir::Type).
StringRef TypeConstraint::getCPPClassName() const {
return def->getValueAsString("cppClassName");
}
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
StringRef Type::getTypeDescription() const {

View File

@ -2137,11 +2137,18 @@ void OpEmitter::genTraits() {
unsigned numVariadicRegions = op.getNumVariadicRegions();
addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
// Add result size trait.
// Add result size traits.
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
// For single result ops with a known specific type, generate a OneTypedResult
// trait.
if (numResults == 1 && numVariadicResults == 0) {
auto cppName = op.getResults().begin()->constraint.getCPPClassName();
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
}
// Add successor size trait.
unsigned numSuccessors = op.getNumSuccessors();
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();