diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 3c2ad549cf2b..3deddd0dfa4a 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -158,6 +158,9 @@ def IsVectorOrTensorTypePred : CPred<"{0}.isa()">; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"{0}.isa()">; +// Whether a type is a MemRefType. +def IsMemRefTypePred : CPred<"{0}.isa()">; + // For a TensorType, verify that it is a statically shaped tensor. def IsStaticShapeTensorTypePred : CPred<"{0}.cast().hasStaticShape()">; @@ -318,6 +321,26 @@ def F64Tensor : TypedTensor; // there is not only a single elemental type. def Tuple : Type; +// Memref type. + +// Memrefs are blocks of data with fixed type and rank. +class MemRef + : ContainerType().getElementType()", "memref">; + +// Memref declarations handle any memref, independent of rank, size, (static or +// dynamic), layout, or memory space. +def I1MemRef : MemRef; +def I8MemRef : MemRef; +def I16MemRef : MemRef; +def I32MemRef : MemRef; +def I64MemRef : MemRef; + +def BF16MemRef : MemRef; +def F16MemRef : MemRef; +def F32MemRef : MemRef; +def F64MemRef : MemRef; + //===----------------------------------------------------------------------===// // Common type constraints //===----------------------------------------------------------------------===//