Add UnrankedMemRef Type

Closes tensorflow/mlir#261

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/261 from nmostafa:nmostafa/unranked 96b6e918f6ed64496f7573b2db33c0b02658ca45
PiperOrigin-RevId: 284037040
nmostafa 2019-12-05 13:12:50 -08:00 committed by A. Unique TensorFlower
parent e67acfa468
commit daff60cd68
23 changed files with 678 additions and 118 deletions

View File

@ -90,6 +90,23 @@ memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, <4 x float>*, i64, [1 x i64], [1 x i64] }">
```
If the rank of the memref is unknown at compile time, the memref is converted to
an unranked descriptor that contains: 1. a 64-bit integer representing the
dynamic rank of the memref, followed by 2. a pointer to a ranked memref
descriptor with the contents listed above.

Unranked memrefs should only be used to pass arguments to external library
calls that expect a unified memref type. The called functions can parse any
unranked memref descriptor by reading the rank and then parsing the enclosed
ranked descriptor pointer.
Examples:
```mlir {.mlir}
// unranked descriptor
memref<*xf32> -> !llvm.type<"{i64, i8*}">
```
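For illustration, here is a minimal sketch of how an external C++ callee could read such a descriptor, assuming the `{i64, i8*}` lowering above and the ranked descriptor layout from the previous section; the struct and function names are hypothetical, not part of MLIR:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the ranked descriptor { float*, float*, i64, [2 x i64], [2 x i64] }
// for a 2-D f32 memref (hypothetical name).
struct RankedMemRef2DF32 {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

// Mirrors the unranked descriptor { i64, i8* }.
struct UnrankedMemRef {
  int64_t rank;
  void *rankedDescriptor;
};

// The callee reads the rank first, then reinterprets the enclosed pointer.
void inspect(UnrankedMemRef *m) {
  std::printf("rank = %lld\n", static_cast<long long>(m->rank));
  if (m->rank == 2) {
    auto *d = static_cast<RankedMemRef2DF32 *>(m->rankedDescriptor);
    std::printf("sizes = [%lld, %lld]\n", static_cast<long long>(d->sizes[0]),
                static_cast<long long>(d->sizes[1]));
  }
}
```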
### Function Types

Function types get converted to LLVM function types. The arguments are converted

View File

@ -912,12 +912,21 @@ Examples:
// Convert to a type with more known dimensions.
%4 = memref_cast %3 : memref<?x?xf32> to memref<4x?xf32>
// Convert to a type with unknown rank.
%5 = memref_cast %3 : memref<?x?xf32> to memref<*xf32>
// Convert to a type with static rank.
%6 = memref_cast %5 : memref<*xf32> to memref<?x?xf32>
```

Convert a memref from one type to an equivalent type without changing any data
elements. The types are equivalent if 1. both have the same static rank,
element type, mappings, and address space (the operation is invalid if
converting to a mismatching constant dimension), or 2. exactly one of the
operands has an unknown rank, and both have the same element type and address
space (the operation is invalid if both operands are of dynamic rank or if
converting to a mismatching static rank).
### 'mulf' operation

View File

@ -760,10 +760,16 @@ TODO: Need to decide on a representation for quantized integers
Syntax:

``` {.ebnf}
memref-type ::= ranked-memref-type | unranked-memref-type

ranked-memref-type ::= `memref` `<` dimension-list-ranked tensor-memref-element-type
                       (`,` layout-specification)? |
                       (`,` memory-space)? `>`

unranked-memref-type ::= `memref` `<*x` tensor-memref-element-type
                         (`,` memory-space)? `>`

stride-list ::= `[` (dimension (`,` dimension)*)? `]`
strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
layout-specification ::= semi-affine-map | strided-layout
@ -774,9 +780,48 @@ A `memref` type is a reference to a region of memory (similar to a buffer
pointer, but more powerful). The buffer pointed to by a memref can be allocated,
aliased and deallocated. A memref can be used to read and write data from/to the
memory region which it references. Memref types use the same shape specifier as
tensor types. Note that `memref<f32>`, `memref<0 x f32>`, `memref<1 x 0 x f32>`,
and `memref<0 x 1 x f32>` are all different types.

A `memref` is allowed to have an unknown rank (e.g. `memref<*xf32>`). The
purpose of unranked memrefs is to allow external library functions to receive
memref arguments of any rank without versioning the functions based on the rank.
Other uses of this type are disallowed or will have undefined behavior.
##### Codegen of Unranked Memref
Using unranked memrefs in codegen besides the case mentioned above is highly
discouraged. Codegen is concerned with generating loop nests and specialized
instructions for high performance; unranked memrefs are concerned with hiding
the rank and thus the number of enclosing loops required to iterate over the
data.

However, if there is a need to code-gen unranked memrefs, one possible path is
to cast into a static ranked type based on the dynamic rank. Another possible
path is to emit a single while loop conditioned on a linear index and perform
delinearization of the linear index into a dynamic array containing the
(unranked) indices. While possible, this is not expected to be a good idea to
perform during codegen, as the cost of the translations is likely prohibitive
and optimizations at this level are not expected to be worthwhile.
If expressiveness is the main concern, irrespective of performance, passing
unranked memrefs to an external C++ library and implementing rank-agnostic logic
there is expected to be significantly simpler.
Unranked memrefs may provide expressiveness gains in the future and help bridge
the gap with unranked tensors. Unranked memrefs are not expected to be exposed
to codegen, but one may query the rank of an unranked memref (a special op will
be needed for this purpose) and perform a switch and cast to a ranked memref as
a prerequisite to codegen.
Example:

```mlir {.mlir}
// With static ranks, we need a function for each possible argument type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
call @helper_2D(%A) : (memref<16x32xf32>)->()
call @helper_3D(%B) : (memref<16x32x64xf32>)->()

// With unknown rank, the functions can be unified under one unranked type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
// Remove rank info
%A_u = memref_cast %A : memref<16x32xf32> to memref<*xf32>
%B_u = memref_cast %B : memref<16x32x64xf32> to memref<*xf32>
// call same function with dynamic ranks
call @helper(%A_u) : (memref<*xf32>)->()
call @helper(%B_u) : (memref<*xf32>)->()
```
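The switch-and-cast pattern mentioned above is the same one the runner utilities in this commit use for printing (see `print_memref_f32` further down). A hedged C++ sketch, with hypothetical names and the descriptor layouts assumed from earlier:

```cpp
#include <cassert>
#include <cstdint>

// Ranked descriptor template mirroring the documented lowering.
template <typename T, int N> struct StridedMemRef {
  T *allocated;
  T *aligned;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};

struct UnrankedDescriptor {
  int64_t rank;
  void *descriptor; // points at a ranked StridedMemRef<T, N>
};

// Rank-specialized logic (e.g. an N-deep loop nest) would go here.
template <int N> void processRanked(StridedMemRef<float, N> &m) { (void)m; }

// Query the rank, then switch and cast to the matching ranked form.
void process(UnrankedDescriptor *u) {
  switch (u->rank) {
  case 1:
    processRanked(*static_cast<StridedMemRef<float, 1> *>(u->descriptor));
    break;
  case 2:
    processRanked(*static_cast<StridedMemRef<float, 2> *>(u->descriptor));
    break;
  default:
    assert(false && "unsupported rank");
  }
}
```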
The core syntax and representation of a layout specification is a
[semi-affine map](Dialects/Affine.md#semi-affine-maps). Additionally, syntactic

View File

@ -34,6 +34,9 @@ class Type;
} // namespace llvm

namespace mlir {
class UnrankedMemRefType;
namespace LLVM {
class LLVMDialect;
class LLVMType;
@ -116,6 +119,10 @@ private:
// 2. as many index types as memref has dynamic dimensions.
Type convertMemRefType(MemRefType type);
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
@ -127,10 +134,34 @@ private:
LLVM::LLVMType unwrap(Type type);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct.
class StructBuilder {
public:
/// Construct a helper for the given value.
explicit StructBuilder(Value *v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
/*implicit*/ operator Value *() { return value; }
protected:
// LLVM value
Value *value;
// Cached struct type.
Type structType;
protected:
/// Builds IR to extract a value from the struct at position pos
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
};
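As a usage sketch: a subclass names the struct fields and forwards to `extractPtr`/`setPtr` with fixed positions, which is exactly what the two descriptor classes below do. The `PairDescriptor` here is hypothetical, for illustration only:

```cpp
// Hypothetical helper for a two-field LLVM struct, built on StructBuilder.
class PairDescriptor : public StructBuilder {
public:
  explicit PairDescriptor(Value *v) : StructBuilder(v) {}

  /// Builds IR extracting field 0 of the struct.
  Value *first(OpBuilder &builder, Location loc) {
    return extractPtr(builder, loc, 0);
  }
  /// Builds IR setting field 0 of the struct.
  void setFirst(OpBuilder &builder, Location loc, Value *v) {
    setPtr(builder, loc, 0, v);
  }
};
```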
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value *descriptor);
@ -169,22 +200,28 @@ public:
/// Returns the (LLVM) type this descriptor points to.
LLVM::LLVMType getElementType();

private:
// Cached index type.
Type indexType;
};
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit UnrankedMemRefDescriptor(Value *descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR extracting the rank from the descriptor
Value *rank(OpBuilder &builder, Location loc);
/// Builds IR setting the rank in the descriptor
void setRank(OpBuilder &builder, Location loc, Value *value);
/// Builds IR extracting ranked memref descriptor ptr
Value *memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value *value);
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
/// purpose of type conversions.

View File

@ -842,7 +842,8 @@ def MemRefCastOp : CastOp<"memref_cast"> {
let description = [{
The "memref_cast" operation converts a memref from one type to an equivalent
type with a compatible shape. The source and destination types are
compatible if:

a. both are ranked memref types with the same element type, affine mappings,
address space, and rank but where the individual dimensions may add or
remove constant dimensions from the memref type.
@ -850,6 +851,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
acts as an assertion that fails at runtime if the dynamic dimensions
disagree with resultant destination size.
Example:
Assert that the input dynamic shape matches the destination static shape.
%2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
Erase static shape information, replacing it with dynamic information.
@ -864,10 +866,20 @@ def MemRefCastOp : CastOp<"memref_cast"> {
dynamic information.
%5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to
memref<12x4xf32, offset:?, strides: [?, ?]>
b. either or both memref types are unranked with the same element type, and
address space.
Example:
Cast to concrete shape.
%4 = memref_cast %1 : memref<*xf32> to memref<4x?xf32>
Erase rank information.
%5 = memref_cast %1 : memref<4x?xf32> to memref<*xf32>
}];

let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
@ -875,7 +887,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
static bool areCastCompatible(Type a, Type b);

/// The result of a memref_cast is always a memref.
Type getType() { return getResult()->getType(); }
}];
}

View File

@ -221,6 +221,9 @@ def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
// Whether a type is an UnrankedMemRefType.
def IsUnrankedMemRefTypePred : CPred<"$_self.isa<UnrankedMemRefType>()">;
// Whether a type is a ShapedType.
def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
@ -486,6 +489,10 @@ class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
// Unranked Memref type
def AnyUnrankedMemRef :
ShapedContainerType<[AnyType],
IsUnrankedMemRefTypePred, "unranked.memref">;
// Memref type.
// Memrefs are blocks of data with fixed type and rank.
@ -494,6 +501,8 @@ class MemRefOf<list<Type> allowedTypes> :
def AnyMemRef : MemRefOf<[AnyType]>;
def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
// Memref declarations handle any memref, independent of rank, size, (static or
// dynamic), layout, or memory space.
def I1MemRef : MemRefOf<[I1]>;

View File

@ -40,6 +40,7 @@ struct VectorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
struct UnrankedMemRefTypeStorage;
struct ComplexTypeStorage;
struct TupleTypeStorage;
@ -64,6 +65,7 @@ enum Kind {
RankedTensor,
UnrankedTensor,
MemRef,
UnrankedMemRef,
Complex,
Tuple,
None,
@ -243,6 +245,7 @@ public:
return type.getKind() == StandardTypes::Vector ||
type.getKind() == StandardTypes::RankedTensor ||
type.getKind() == StandardTypes::UnrankedTensor ||
type.getKind() == StandardTypes::UnrankedMemRef ||
type.getKind() == StandardTypes::MemRef;
}
@ -370,12 +373,24 @@ public:
}
};
/// Base MemRef for Ranked and Unranked variants
class BaseMemRefType : public ShapedType {
public:
using ShapedType::ShapedType;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type) {
return type.getKind() == StandardTypes::MemRef ||
type.getKind() == StandardTypes::UnrankedMemRef;
}
};
/// MemRef types represent a region of memory that has a shape with a fixed
/// number of dimensions. Each shape element can be a non-negative integer or
/// unknown (represented by any negative integer). MemRef types also have an
/// affine map composition, represented as an array of AffineMap pointers.
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
detail::MemRefTypeStorage> {
public:
using Base::Base;
@ -426,6 +441,40 @@ private:
using Base::getImpl;
};
/// Unranked MemRef type represent multi-dimensional MemRefs that
/// have an unknown rank.
class UnrankedMemRefType
: public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
detail::UnrankedMemRefTypeStorage> {
public:
using Base::Base;
/// Get or create a new UnrankedMemRefType of the provided element
/// type and memory space
static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
/// Get or create a new UnrankedMemRefType of the provided element
/// type and memory space declared at the given, potentially unknown,
/// location. If the UnrankedMemRefType defined by the arguments would be
/// ill-formed, emit errors and return a nullptr-wrapping type.
static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
Location location);
/// Verify the construction of an unranked memref type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Type elementType,
unsigned memorySpace);
ArrayRef<int64_t> getShape() const { return llvm::None; }
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;
static bool kindof(unsigned kind) {
return kind == StandardTypes::UnrankedMemRef;
}
};
/// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provide operations for manipulating them, e.g.

View File

@ -193,6 +193,22 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
}
// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
// contains:
// 1. int64_t rank, the dynamic rank of this MemRef
// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
// a stack-allocated (alloca) copy of a MemRef descriptor that got cast to
// be unranked.
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
@ -221,6 +237,8 @@ Type LLVMTypeConverter::convertStandardType(Type type) {
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
return convertUnrankedMemRefType(memRefType);
if (auto vectorType = type.dyn_cast<VectorType>())
return convertVectorType(vectorType);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
@ -245,22 +263,42 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
/*============================================================================*/
/* StructBuilder implementation */
/*============================================================================*/
StructBuilder::StructBuilder(Value *v) : value(v) {
assert(value != nullptr && "value cannot be null");
structType = value->getType().cast<LLVM::LLVMType>();
}
Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value *ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
/*============================================================================*/
/* MemRefDescriptor implementation */
/*============================================================================*/

/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value *descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
kOffsetPosInMemRefDescriptor);
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value *descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
@ -334,24 +372,42 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
}
/*============================================================================*/
/* UnrankedMemRefDescriptor implementation */
/*============================================================================*/
/// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor)
: StructBuilder(descriptor) {}
/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value *descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return UnrankedMemRefDescriptor(descriptor);
}
Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
Value *v) {
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
Location loc) {
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
Location loc, Value *v) {
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
@ -432,7 +488,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return matchFailure();
if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) {
converted = converted.getPointerTo();
promotedArgIndices.push_back(en.index());
}
@ -983,6 +1039,14 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
Type packedResult;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
for (Type resType : resultTypes) {
assert(!resType.isa<UnrankedMemRefType>() &&
"Returning unranked memref is not supported. Pass result as an"
"argument instead.");
(void)resType;
}
if (numResults != 0) {
if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
return this->matchFailure();
@ -1076,25 +1140,93 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
PatternMatchResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
Type srcType = memRefCastOp.getOperand()->getType();
Type dstType = memRefCastOp.getType();
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
MemRefType sourceType =
memRefCastOp.getOperand()->getType().cast<MemRefType>();
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
? matchSuccess()
: matchFailure();
}
// At least one of the operands is an unranked type.
assert(srcType.isa<UnrankedMemRefType>() ||
dstType.isa<UnrankedMemRefType>());
// Unranked to unranked cast is disallowed
return !(srcType.isa<UnrankedMemRefType>() &&
dstType.isa<UnrankedMemRefType>())
? matchSuccess()
: matchFailure();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
auto srcType = memRefCastOp.getOperand()->getType();
auto dstType = memRefCastOp.getType();
auto targetStructType = lowering.convertType(memRefCastOp.getType());
auto loc = op->getLoc();
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
// memref_cast is defined for source and destination memref types with the
// same element type, same mappings, same address space and same rank.
// Therefore a simple bitcast suffices. If not it is undefined behavior.
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
transformed.source());
} else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
// Casting ranked to unranked memref type
// Set the rank in the destination from the memref type
// Allocate space on the stack and copy the src memref descriptor
// Set the ptr in the destination to the stack space
auto srcMemRefType = srcType.cast<MemRefType>();
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
rewriter);
// voidptr = BitCastOp srcType* to void*
auto voidPtr =
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
.getResult();
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(rewriter.getIntegerType(64)),
rewriter.getI64IntegerAttr(rank));
// undef = UndefOp
UnrankedMemRefDescriptor memRefDesc =
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
// d1 = InsertValueOp undef, rank, 0
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, voidptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
rewriter.replaceOp(op, (Value *)memRefDesc);
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
// Casting from unranked type to ranked.
// The operation is assumed to be doing a correct cast. If the destination
// type mismatches the unranked type, it is undefined behavior.
UnrankedMemRefDescriptor memRefDesc(transformed.source());
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
auto castPtr =
rewriter
.create<LLVM::BitcastOp>(
loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
ptr)
.getResult();
// struct = LoadOp castPtr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
rewriter.replaceOp(op, loadOp.getResult());
} else {
llvm_unreachable("Unsuppored unranked memref to unranked memref cast");
}
} }
}; };
@ -1896,7 +2028,8 @@ SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
for (auto it : llvm::zip(opOperands, operands)) {
auto *operand = std::get<0>(it);
auto *llvmOperand = std::get<1>(it);
if (!operand->getType().isa<MemRefType>() &&
!operand->getType().isa<UnrankedMemRefType>()) {
promotedOperands.push_back(operand);
continue;
}

View File

@ -1769,8 +1769,10 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
auto aT = a.dyn_cast<MemRefType>();
auto bT = b.dyn_cast<MemRefType>();

auto uaT = a.dyn_cast<UnrankedMemRefType>();
auto ubT = b.dyn_cast<UnrankedMemRefType>();

if (aT && bT) {
if (aT.getElementType() != bT.getElementType())
return false;
if (aT.getAffineMaps() != bT.getAffineMaps()) {
@ -1783,8 +1785,8 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
// Strides along a dimension/offset are compatible if the value in the
// source memref is static and the value in the target memref is the
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
auto checkCompatible = [](int64_t a, int64_t b) {
return (a == MemRefType::getDynamicStrideOrOffset() ||
b == MemRefType::getDynamicStrideOrOffset() || a == b);
@ -1807,8 +1809,30 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
if (aDim != -1 && bDim != -1 && aDim != bDim)
return false;
}
return true;
} else {
if (!aT && !uaT)
return false;
if (!bT && !ubT)
return false;
// Unranked to unranked casting is unsupported
if (uaT && ubT)
return false;
auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
if (aEltType != bEltType)
return false;
auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
if (aMemSpace != bMemSpace)
return false;
return true;
}
return false;
}
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {

View File

@ -1086,6 +1086,13 @@ void ModulePrinter::printType(Type type) {
os << '>';
return;
}
case StandardTypes::UnrankedMemRef: {
auto v = type.cast<UnrankedMemRefType>();
os << "memref<*x";
printType(v.getElementType());
os << '>';
return;
}
case StandardTypes::Complex:
os << "complex<";
printType(type.cast<ComplexType>().getElementType());

View File

@ -90,8 +90,8 @@ struct BuiltinDialect : public Dialect {
UnknownLoc>();

addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
RankedTensorType, TupleType, UnrankedTensorType, VectorType>();

// TODO: These operations should be moved to a different dialect when they
// have been fully decoupled from the core.

View File

@ -390,6 +390,37 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
unsigned memorySpace) {
return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
elementType, memorySpace);
}
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
unsigned memorySpace,
Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::UnrankedMemRef, elementType,
memorySpace);
}
unsigned UnrankedMemRefType::getMemorySpace() const {
return getImpl()->memorySpace;
}
LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
return emitOptionalError(*loc, "invalid memref element type");
return success();
}
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become

View File

@ -119,8 +119,8 @@ struct FunctionTypeStorage : public TypeStorage {
/// Shaped Type Storage.
struct ShapedTypeStorage : public TypeStorage {
ShapedTypeStorage(Type elementTy, unsigned subclassData = 0)
: TypeStorage(subclassData), elementType(elementTy) {}

/// The hash key used for uniquing.
using KeyTy = Type;
@ -252,6 +252,31 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
const unsigned memorySpace;
};
/// Unranked MemRef is a MemRef with unknown rank.
/// Only element type and memory space are known
struct UnrankedMemRefTypeStorage : public ShapedTypeStorage {
UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace)
: ShapedTypeStorage(elementTy), memorySpace(memorySpace) {}
/// The hash key used for uniquing.
using KeyTy = std::tuple<Type, unsigned>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, memorySpace);
}
/// Construction.
static UnrankedMemRefTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
// Initialize the memory using placement new.
return new (allocator.allocate<UnrankedMemRefTypeStorage>())
UnrankedMemRefTypeStorage(std::get<0>(key), std::get<1>(key));
}
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
};
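For context, here is a minimal sketch of what this uniquing buys at the API level, assuming the `UnrankedMemRefType::get` declared in StandardTypes.h earlier (`uniquingExample` is illustrative, not part of the commit): requesting the same (element type, memory space) key twice yields the same uniqued storage, so the returned types compare equal.

```cpp
#include <cassert>
#include "mlir/IR/StandardTypes.h"

void uniquingExample(mlir::MLIRContext *ctx) {
  auto f32 = mlir::FloatType::getF32(ctx);
  auto a = mlir::UnrankedMemRefType::get(f32, /*memorySpace=*/0);
  auto b = mlir::UnrankedMemRefType::get(f32, /*memorySpace=*/0);
  // Same key -> same storage instance -> equal types.
  assert(a == b);
}
```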
/// Complex Type Storage.
struct ComplexTypeStorage : public TypeStorage {
ComplexTypeStorage(Type elementType) : elementType(elementType) {}

View File

@ -1054,8 +1054,13 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
/// Parse a memref type.
///
/// memref-type ::= ranked-memref-type | unranked-memref-type
///
/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
/// (`,` semi-affine-map-composition)? (`,`
/// memory-space)? `>`
///
/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
///
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
/// memory-space ::= integer-literal /* | TODO: address-space-id */
@ -1066,9 +1071,20 @@ Type Parser::parseMemRefType() {
if (parseToken(Token::less, "expected '<' in memref type"))
return nullptr;
bool isUnranked;
SmallVector<int64_t, 4> dimensions;
if (consumeIf(Token::star)) {
// This is an unranked memref type.
isUnranked = true;
if (parseXInDimensionList())
return nullptr;
} else {
isUnranked = false;
if (parseDimensionListRanked(dimensions))
return nullptr;
}
// Parse the element type.
auto typeLoc = getToken().getLoc();
@ -1093,6 +1109,8 @@ Type Parser::parseMemRefType() {
consumeToken(Token::integer);
parsedMemorySpace = true;
} else {
if (isUnranked)
return emitError("cannot have affine map for unranked memref type");
if (parsedMemorySpace)
return emitError("expected memory space to be last in memref type");
if (getToken().is(Token::kw_offset)) {
@ -1131,6 +1149,10 @@ Type Parser::parseMemRefType() {
return nullptr;
}
if (isUnranked)
return UnrankedMemRefType::getChecked(elementType, memorySpace,
getEncodedSourceLocation(typeLoc));
return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
memorySpace, getEncodedSourceLocation(typeLoc));
}

View File

@ -371,6 +371,30 @@ func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
return
}
// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK: llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[p2]], %{{.*}}[1] : !llvm<"{ i64, i8* }">
%0 = memref_cast %arg : memref<42x2x?xf32> to memref<*xf32>
return
}
// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
%0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
return
}
// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">

View File

@ -565,6 +565,11 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64
// CHECK: {{%.*}} = memref_cast {{%.*}} : memref<64x16x4xf32, #[[BASE_MAP3]]> to memref<64x16x4xf32, #[[BASE_MAP0]]>
%3 = memref_cast %2 : memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> to memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>
// CHECK: memref_cast %{{.*}} : memref<4xf32> to memref<*xf32>
%4 = memref_cast %1 : memref<4xf32> to memref<*xf32>
// CHECK: memref_cast %{{.*}} : memref<*xf32> to memref<4xf32>
%5 = memref_cast %4 : memref<*xf32> to memref<4xf32>
return
}

View File

@ -978,3 +978,34 @@ func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16,
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]>
return
}
// -----
// incompatible element types
func @invalid_memref_cast() {
%0 = alloc() : memref<2x5xf32, 0>
// expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}}
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xi32>
return
}
// -----
// incompatible memory space
func @invalid_memref_cast() {
%0 = alloc() : memref<2x5xf32, 0>
// expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32>' are cast incompatible}}
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1>
return
}
// -----
// unranked to unranked
func @invalid_memref_cast() {
%0 = alloc() : memref<2x5xf32, 0>
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
// expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
%2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
return
}

View File

@ -78,6 +78,12 @@ template <typename T> struct StridedMemRefType<T, 0> {
int64_t offset;
};
// Unranked MemRef
struct UnrankedMemRefType {
int64_t rank;
void *descriptor;
};
template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
static_assert(N > 0, "Expected N > 0");
@ -97,6 +103,15 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
<< " offset = " << V.offset; << " offset = " << V.offset;
} }
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_f32(UnrankedMemRefType *M);
template <typename StreamType>
void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) {
os << "Unranked Memref rank = " << V.rank << " "
<< "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
}
extern "C" MLIR_RUNNER_UTILS_EXPORT void extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M); print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void extern "C" MLIR_RUNNER_UTILS_EXPORT void

View File

@ -148,15 +148,41 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
std::cout << std::endl;
}

template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
<< " offset = " << M.offset << " data = [";
MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
std::cout << "]" << std::endl;
}
extern "C" void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
printMemRef(*M);
}
extern "C" void print_memref_f32(UnrankedMemRefType *M) {
printUnrankedMemRefMetaData(std::cout, *M);
int rank = M->rank;
void *ptr = M->descriptor;
#define MEMREF_CASE(RANK) \
case RANK: \
printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr))); \
break
switch (rank) {
MEMREF_CASE(0);
MEMREF_CASE(1);
MEMREF_CASE(2);
MEMREF_CASE(3);
MEMREF_CASE(4);
default:
assert(0 && "Unsupported rank to print");
}
}
extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) { extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
printZeroDMemRef(*M); printMemRef(*M);
} }
extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) { extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
printMemRef(*M); printMemRef(*M);
@ -170,8 +196,3 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) { extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
printMemRef(*M); printMemRef(*M);
} }
extern "C" void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
printMemRef(*M);
}

View File

@ -0,0 +1,43 @@
// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [10, 10, 10]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [5, 5, 5]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [2, 2, 2]
func @main() -> () {
%A = alloc() : memref<10x3xf32, 0>
%f2 = constant 2.00000e+00 : f32
%f5 = constant 5.00000e+00 : f32
%f10 = constant 10.00000e+00 : f32
%V = memref_cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
linalg.fill(%V, %f10) : memref<?x?xf32, 0>, f32
%U = memref_cast %A : memref<10x3xf32, 0> to memref<*xf32>
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
%V2 = memref_cast %U : memref<*xf32> to memref<?x?xf32>
linalg.fill(%V2, %f5) : memref<?x?xf32, 0>, f32
%U2 = memref_cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
call @print_memref_f32(%U2) : (memref<*xf32>) -> ()
%V3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
%V4 = memref_cast %V3 : memref<*xf32> to memref<?x?xf32>
linalg.fill(%V4, %f2) : memref<?x?xf32, 0>, f32
%U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
dealloc %A : memref<10x3xf32, 0>
return
}
func @print_memref_f32(memref<*xf32>)

View File

@ -7,7 +7,8 @@ func @print_0d() {
%f = constant 2.00000e+00 : f32
%A = alloc() : memref<f32>
store %f, %A[]: memref<f32>
%U = memref_cast %A : memref<f32> to memref<*xf32>
call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<f32>
return
}
@ -18,7 +19,8 @@ func @print_1d() {
%A = alloc() : memref<16xf32>
%B = memref_cast %A: memref<16xf32> to memref<?xf32>
linalg.fill(%B, %f) : memref<?xf32>, f32
%U = memref_cast %B : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<16xf32>
return
}
@ -34,8 +36,8 @@ func @print_3d() {
%c2 = constant 2 : index
store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32>
%U = memref_cast %B : memref<?x?x?xf32> to memref<*xf32>
call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<3x4x5xf32>
return
}
@ -46,10 +48,7 @@ func @print_3d() {
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
func @print_memref_f32(memref<*xf32>)
!vector_type_C = type vector<4x4xf32>
!matrix_type_CC = type memref<1x1x!vector_type_C>

View File

@ -22,9 +22,10 @@ func @main() {
store %sum, %kernel_dst[%tz, %ty, %tx] : memref<?x?x?xf32>
gpu.return
}
%U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
return
}
func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@ -20,9 +20,10 @@ func @main() {
store %res, %kernel_dst[%tx] : memref<?xf32>
gpu.return
}
%U = memref_cast %dst : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
return
}
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
func @print_memref_f32(memref<*xf32>)