forked from OSchip/llvm-project
Add UnrankedMemRef Type
Closes tensorflow/mlir#261 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/261 from nmostafa:nmostafa/unranked 96b6e918f6ed64496f7573b2db33c0b02658ca45 PiperOrigin-RevId: 284037040
This commit is contained in:
parent
e67acfa468
commit
daff60cd68
|
@ -90,6 +90,23 @@ memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, float*, i64, [5 x i64], [5
|
|||
memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, <4 x float>*, i64, [1 x i64], [1 x i64] }">
|
||||
```
|
||||
|
||||
If the rank of the memref is unknown at compile time, the memref is converted to
|
||||
an unranked descriptor that contains: 1. a 64-bit integer representing the
|
||||
dynamic rank of the memref, followed by 2. a pointer to a ranked memref
|
||||
descriptor with the contents listed above.
|
||||
|
||||
Unranked memrefs should be used only to pass arguments to external library
|
||||
calls that expect a unified memref type. The called functions can parse any
|
||||
unranked memref descriptor by reading the rank and parsing the enclosed ranked
|
||||
descriptor pointer.
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// unranked descriptor
|
||||
memref<*xf32> -> !llvm.type<"{i64, i8*}">
|
||||
```
|
||||
|
||||
### Function Types
|
||||
|
||||
Function types get converted to LLVM function types. The arguments are converted
|
||||
|
|
|
@ -912,12 +912,21 @@ Examples:
|
|||
|
||||
// Convert to a type with more known dimensions.
|
||||
%4 = memref_cast %3 : memref<?x?xf32> to memref<4x?xf32>
|
||||
|
||||
// Convert to a type with unknown rank.
|
||||
%5 = memref_cast %3 : memref<?x?xf32> to memref<*xf32>
|
||||
|
||||
// Convert to a type with static rank.
|
||||
%6 = memref_cast %5 : memref<*xf32> to memref<?x?xf32>
|
||||
```
|
||||
|
||||
Convert a memref from one type to an equivalent type without changing any data
|
||||
elements. The source and destination types must both be memref types with the
|
||||
same element type, same mappings, same address space, and same rank. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
elements. The types are equivalent if 1. they both have the same static rank,
|
||||
same element type, same mappings, same address space. The operation is invalid
|
||||
if converting to a mismatching constant dimension, or 2. exactly one of the
|
||||
operands has an unknown rank, and they both have the same element type and same
|
||||
address space. The operation is invalid if both operands are of unknown rank or
|
||||
if converting to a mismatching static rank.
|
||||
|
||||
### 'mulf' operation
|
||||
|
||||
|
|
|
@ -760,9 +760,15 @@ TODO: Need to decide on a representation for quantized integers
|
|||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
memref-type ::= `memref` `<` dimension-list-ranked tensor-memref-element-type
|
||||
(`,` layout-specification)? |
|
||||
(`,` memory-space)? `>`
|
||||
|
||||
memref-type ::= ranked-memref-type | unranked-memref-type
|
||||
|
||||
ranked-memref-type ::= `memref` `<` dimension-list-ranked tensor-memref-element-type
|
||||
(`,` layout-specification)? |
|
||||
(`,` memory-space)? `>`
|
||||
|
||||
unranked-memref-type ::= `memref` `<*x` tensor-memref-element-type
|
||||
(`,` memory-space)? `>`
|
||||
|
||||
stride-list ::= `[` (dimension (`,` dimension)*)? `]`
|
||||
strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
|
||||
|
@ -774,9 +780,48 @@ A `memref` type is a reference to a region of memory (similar to a buffer
|
|||
pointer, but more powerful). The buffer pointed to by a memref can be allocated,
|
||||
aliased and deallocated. A memref can be used to read and write data from/to the
|
||||
memory region which it references. Memref types use the same shape specifier as
|
||||
tensor types, but do not allow unknown rank. Note that `memref<f32>`, `memref<0
|
||||
x f32>`, `memref<1 x 0 x f32>`, and `memref<0 x 1 x f32>` are all different
|
||||
types.
|
||||
tensor types. Note that `memref<f32>`, `memref<0 x f32>`, `memref<1 x 0 x f32>`,
|
||||
and `memref<0 x 1 x f32>` are all different types.
|
||||
|
||||
A `memref` is allowed to have an unknown rank (e.g. `memref<*xf32>`). The
|
||||
purpose of unranked memrefs is to allow external library functions to receive
|
||||
memref arguments of any rank without versioning the functions based on the rank.
|
||||
Other uses of this type are disallowed or will have undefined behavior.
|
||||
|
||||
##### Codegen of Unranked Memref
|
||||
|
||||
Using unranked memref in codegen besides the case mentioned above is highly
|
||||
discouraged. Codegen is concerned with generating loop nests and specialized
|
||||
instructions for high performance, whereas unranked memref is concerned with hiding the
|
||||
rank and thus, the number of enclosing loops required to iterate over the data.
|
||||
However, if there is a need to code-gen unranked memref, one possible path is to
|
||||
cast into a static ranked type based on the dynamic rank. Another possible path
|
||||
is to emit a single while loop conditioned on a linear index and perform
|
||||
delinearization of the linear index to a dynamic array containing the (unranked)
|
||||
indices. While this is possible, it is expected to not be a good idea to perform
|
||||
this during codegen as the cost of the translations is expected to be
|
||||
prohibitive and optimizations at this level are not expected to be worthwhile.
|
||||
If expressiveness is the main concern, irrespective of performance, passing
|
||||
unranked memrefs to an external C++ library and implementing rank-agnostic logic
|
||||
there is expected to be significantly simpler.
|
||||
|
||||
Unranked memrefs may provide expressiveness gains in the future and help bridge
|
||||
the gap with unranked tensors. Unranked memrefs will not be expected to be
|
||||
exposed to codegen but one may query the rank of an unranked memref (a special
|
||||
op will be needed for this purpose) and perform a switch and cast to a ranked
|
||||
memref as a prerequisite to codegen.
|
||||
|
||||
Example:

```mlir {.mlir}
// With static ranks, we need a function for each possible argument type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
call @helper_2D(%A) : (memref<16x32xf32>)->()
call @helper_3D(%B) : (memref<16x32x64xf32>)->()

// With unknown rank, the functions can be unified under one unranked type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
// Remove rank info
%A_u = memref_cast %A : memref<16x32xf32> to memref<*xf32>
%B_u = memref_cast %B : memref<16x32x64xf32> to memref<*xf32>
// call same function with dynamic ranks
call @helper(%A_u) : (memref<*xf32>)->()
call @helper(%B_u) : (memref<*xf32>)->()
```
|
||||
|
||||
The core syntax and representation of a layout specification is a
|
||||
[semi-affine map](Dialects/Affine.md#semi-affine-maps). Additionally, syntactic
|
||||
|
|
|
@ -34,6 +34,9 @@ class Type;
|
|||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class UnrankedMemRefType;
|
||||
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
class LLVMType;
|
||||
|
@ -116,6 +119,10 @@ private:
|
|||
// 2. as many index types as memref has dynamic dimensions.
|
||||
Type convertMemRefType(MemRefType type);
|
||||
|
||||
// Convert an unranked memref type to an LLVM type that captures the
|
||||
// runtime rank and a pointer to the static ranked memref descriptor.
|
||||
Type convertUnrankedMemRefType(UnrankedMemRefType type);
|
||||
|
||||
// Convert a 1D vector type into an LLVM vector type.
|
||||
Type convertVectorType(VectorType type);
|
||||
|
||||
|
@ -127,10 +134,34 @@ private:
|
|||
LLVM::LLVMType unwrap(Type type);
|
||||
};
|
||||
|
||||
/// Helper class to produce LLVM dialect operations extracting or inserting
|
||||
/// values to a struct.
|
||||
class StructBuilder {
|
||||
public:
|
||||
/// Construct a helper for the given value.
|
||||
explicit StructBuilder(Value *v);
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
static StructBuilder undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType);
|
||||
|
||||
/*implicit*/ operator Value *() { return value; }
|
||||
|
||||
protected:
|
||||
// LLVM value
|
||||
Value *value;
|
||||
// Cached struct type.
|
||||
Type structType;
|
||||
|
||||
protected:
|
||||
/// Builds IR to extract a value from the struct at position pos
|
||||
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
|
||||
/// Builds IR to set a value in the struct at position pos
|
||||
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
|
||||
};
|
||||
/// Helper class to produce LLVM dialect operations extracting or inserting
|
||||
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
|
||||
/// The Value may be null, in which case none of the operations are valid.
|
||||
class MemRefDescriptor {
|
||||
class MemRefDescriptor : public StructBuilder {
|
||||
public:
|
||||
/// Construct a helper for the given descriptor value.
|
||||
explicit MemRefDescriptor(Value *descriptor);
|
||||
|
@ -169,22 +200,28 @@ public:
|
|||
/// Returns the (LLVM) type this descriptor points to.
|
||||
LLVM::LLVMType getElementType();
|
||||
|
||||
/*implicit*/ operator Value *() { return value; }
|
||||
|
||||
private:
|
||||
Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
|
||||
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
|
||||
|
||||
// Cached descriptor type.
|
||||
Type structType;
|
||||
|
||||
// Cached index type.
|
||||
Type indexType;
|
||||
|
||||
// Actual descriptor.
|
||||
Value *value;
|
||||
};
|
||||
|
||||
class UnrankedMemRefDescriptor : public StructBuilder {
|
||||
public:
|
||||
/// Construct a helper for the given descriptor value.
|
||||
explicit UnrankedMemRefDescriptor(Value *descriptor);
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType);
|
||||
|
||||
/// Builds IR extracting the rank from the descriptor
|
||||
Value *rank(OpBuilder &builder, Location loc);
|
||||
/// Builds IR setting the rank in the descriptor
|
||||
void setRank(OpBuilder &builder, Location loc, Value *value);
|
||||
/// Builds IR extracting ranked memref descriptor ptr
|
||||
Value *memRefDescPtr(OpBuilder &builder, Location loc);
|
||||
/// Builds IR setting ranked memref descriptor ptr
|
||||
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value *value);
|
||||
};
|
||||
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
|
||||
/// conversion patterns with an access to the containing LLVMLowering for the
|
||||
/// purpose of type conversions.
|
||||
|
|
|
@ -842,7 +842,8 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
let description = [{
|
||||
The "memref_cast" operation converts a memref from one type to an equivalent
|
||||
type with a compatible shape. The source and destination types are
|
||||
when both are memref types with the same element type, affine mappings,
|
||||
compatible if:
|
||||
a. both are ranked memref types with the same element type, affine mappings,
|
||||
address space, and rank but where the individual dimensions may add or
|
||||
remove constant dimensions from the memref type.
|
||||
|
||||
|
@ -850,6 +851,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
acts as an assertion that fails at runtime of the dynamic dimensions
|
||||
disagree with resultant destination size.
|
||||
|
||||
Example:
|
||||
Assert that the input dynamic shape matches the destination static shape.
|
||||
%2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
|
||||
Erase static shape information, replacing it with dynamic information.
|
||||
|
@ -864,10 +866,20 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
dynamic information.
|
||||
%5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to
|
||||
memref<12x4xf32, offset:?, strides: [?, ?]>
|
||||
|
||||
b. exactly one of the memref types is unranked, and both have the same element type and
|
||||
address space.
|
||||
|
||||
Example:
|
||||
Cast to concrete shape.
|
||||
%4 = memref_cast %1 : memref<*xf32> to memref<4x?xf32>
|
||||
|
||||
Erase rank information.
|
||||
%5 = memref_cast %1 : memref<4x?xf32> to memref<*xf32>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyMemRef:$source);
|
||||
let results = (outs AnyMemRef);
|
||||
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
|
||||
let results = (outs AnyRankedOrUnrankedMemRef);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
|
@ -875,7 +887,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
Type getType() { return getResult()->getType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -221,6 +221,9 @@ def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
|
|||
// Whether a type is a MemRefType.
|
||||
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
|
||||
|
||||
// Whether a type is an UnrankedMemRefType.
|
||||
def IsUnrankedMemRefTypePred : CPred<"$_self.isa<UnrankedMemRefType>()">;
|
||||
|
||||
// Whether a type is a ShapedType.
|
||||
def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
|
||||
|
||||
|
@ -486,6 +489,10 @@ class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
|
|||
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
|
||||
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
|
||||
|
||||
// Unranked Memref type
|
||||
def AnyUnrankedMemRef :
|
||||
ShapedContainerType<[AnyType],
|
||||
IsUnrankedMemRefTypePred, "unranked.memref">;
|
||||
// Memref type.
|
||||
|
||||
// Memrefs are blocks of data with fixed type and rank.
|
||||
|
@ -494,6 +501,8 @@ class MemRefOf<list<Type> allowedTypes> :
|
|||
|
||||
def AnyMemRef : MemRefOf<[AnyType]>;
|
||||
|
||||
def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
|
||||
|
||||
// Memref declarations handle any memref, independent of rank, size, (static or
|
||||
// dynamic), layout, or memory space.
|
||||
def I1MemRef : MemRefOf<[I1]>;
|
||||
|
|
|
@ -40,6 +40,7 @@ struct VectorTypeStorage;
|
|||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
struct UnrankedMemRefTypeStorage;
|
||||
struct ComplexTypeStorage;
|
||||
struct TupleTypeStorage;
|
||||
|
||||
|
@ -64,6 +65,7 @@ enum Kind {
|
|||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
UnrankedMemRef,
|
||||
Complex,
|
||||
Tuple,
|
||||
None,
|
||||
|
@ -243,6 +245,7 @@ public:
|
|||
return type.getKind() == StandardTypes::Vector ||
|
||||
type.getKind() == StandardTypes::RankedTensor ||
|
||||
type.getKind() == StandardTypes::UnrankedTensor ||
|
||||
type.getKind() == StandardTypes::UnrankedMemRef ||
|
||||
type.getKind() == StandardTypes::MemRef;
|
||||
}
|
||||
|
||||
|
@ -370,12 +373,24 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Base MemRef for Ranked and Unranked variants
|
||||
class BaseMemRefType : public ShapedType {
|
||||
public:
|
||||
using ShapedType::ShapedType;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() == StandardTypes::MemRef ||
|
||||
type.getKind() == StandardTypes::UnrankedMemRef;
|
||||
}
|
||||
};
|
||||
|
||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||
/// number of dimensions. Each shape element can be a non-negative integer or
|
||||
/// unknown (represented by any negative integer). MemRef types also have an
|
||||
/// affine map composition, represented as an array AffineMap pointers.
|
||||
class MemRefType
|
||||
: public Type::TypeBase<MemRefType, ShapedType, detail::MemRefTypeStorage> {
|
||||
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
|
||||
detail::MemRefTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
|
@ -426,6 +441,40 @@ private:
|
|||
using Base::getImpl;
|
||||
};
|
||||
|
||||
/// Unranked MemRef types represent multi-dimensional MemRefs that
|
||||
/// have an unknown rank.
|
||||
class UnrankedMemRefType
|
||||
: public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
|
||||
detail::UnrankedMemRefTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create a new UnrankedMemRefType of the provided element
|
||||
/// type and memory space
|
||||
static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
|
||||
|
||||
/// Get or create a new UnrankedMemRefType of the provided element
|
||||
/// type and memory space declared at the given, potentially unknown,
|
||||
/// location. If the UnrankedMemRefType defined by the arguments would be
|
||||
/// ill-formed, emit errors and return a nullptr-wrapping type.
|
||||
static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
|
||||
Location location);
|
||||
|
||||
/// Verify the construction of an unranked memref type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Type elementType,
|
||||
unsigned memorySpace);
|
||||
|
||||
ArrayRef<int64_t> getShape() const { return llvm::None; }
|
||||
|
||||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const;
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardTypes::UnrankedMemRef;
|
||||
}
|
||||
};
|
||||
|
||||
/// Tuple types represent a collection of other types. Note: This type merely
|
||||
/// provides a common mechanism for representing tuples in MLIR. It is up to
|
||||
/// dialect authors to provides operations for manipulating them, e.g.
|
||||
|
|
|
@ -193,6 +193,22 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
|||
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
|
||||
}
|
||||
|
||||
// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
|
||||
// contains:
|
||||
// 1. int64_t rank, the dynamic rank of this MemRef
|
||||
// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
|
||||
// stack allocated (alloca) copy of a MemRef descriptor that was cast to
|
||||
// be unranked.
|
||||
|
||||
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
|
||||
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
|
||||
|
||||
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
|
||||
auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
|
||||
}
|
||||
|
||||
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
|
||||
// n > 1.
|
||||
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
|
||||
|
@ -221,6 +237,8 @@ Type LLVMTypeConverter::convertStandardType(Type type) {
|
|||
return convertIndexType(indexType);
|
||||
if (auto memRefType = type.dyn_cast<MemRefType>())
|
||||
return convertMemRefType(memRefType);
|
||||
if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
|
||||
return convertUnrankedMemRefType(memRefType);
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return convertVectorType(vectorType);
|
||||
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
|
||||
|
@ -245,22 +263,42 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
|
|||
PatternBenefit benefit)
|
||||
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
|
||||
|
||||
/*============================================================================*/
|
||||
/* StructBuilder implementation */
|
||||
/*============================================================================*/
|
||||
StructBuilder::StructBuilder(Value *v) : value(v) {
|
||||
assert(value != nullptr && "value cannot be null");
|
||||
structType = value->getType().cast<LLVM::LLVMType>();
|
||||
}
|
||||
|
||||
Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc,
|
||||
unsigned pos) {
|
||||
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
|
||||
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *ptr) {
|
||||
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
/*============================================================================*/
|
||||
/* MemRefDescriptor implementation */
|
||||
/*============================================================================*/
|
||||
|
||||
/// Construct a helper for the given descriptor value.
|
||||
MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
|
||||
if (value) {
|
||||
structType = value->getType().cast<LLVM::LLVMType>();
|
||||
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
kOffsetPosInMemRefDescriptor);
|
||||
}
|
||||
MemRefDescriptor::MemRefDescriptor(Value *descriptor)
|
||||
: StructBuilder(descriptor) {
|
||||
assert(value != nullptr && "value cannot be null");
|
||||
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
kOffsetPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
|
||||
Type descriptorType) {
|
||||
|
||||
Value *descriptor =
|
||||
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
|
||||
return MemRefDescriptor(descriptor);
|
||||
|
@ -334,24 +372,42 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
|
|||
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
|
||||
}
|
||||
|
||||
Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
|
||||
unsigned pos) {
|
||||
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
|
||||
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
|
||||
Value *ptr) {
|
||||
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
||||
builder.getI64ArrayAttr(pos));
|
||||
}
|
||||
|
||||
LLVM::LLVMType MemRefDescriptor::getElementType() {
|
||||
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
kAlignedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* UnrankedMemRefDescriptor implementation */
|
||||
/*============================================================================*/
|
||||
|
||||
/// Construct a helper for the given descriptor value.
|
||||
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor)
|
||||
: StructBuilder(descriptor) {}
|
||||
|
||||
/// Builds IR creating an `undef` value of the descriptor type.
|
||||
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
|
||||
Location loc,
|
||||
Type descriptorType) {
|
||||
Value *descriptor =
|
||||
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
|
||||
return UnrankedMemRefDescriptor(descriptor);
|
||||
}
|
||||
Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
|
||||
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
|
||||
}
|
||||
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
|
||||
Value *v) {
|
||||
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
|
||||
}
|
||||
Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
|
||||
Location loc) {
|
||||
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
|
||||
}
|
||||
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
|
||||
Location loc, Value *v) {
|
||||
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
|
||||
}
|
||||
namespace {
|
||||
// Base class for Standard to LLVM IR op conversions. Matches the Op type
|
||||
// provided as template argument. Carries a reference to the LLVM dialect in
|
||||
|
@ -432,7 +488,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
|
|||
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
|
||||
if (!converted)
|
||||
return matchFailure();
|
||||
if (t.isa<MemRefType>()) {
|
||||
if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) {
|
||||
converted = converted.getPointerTo();
|
||||
promotedArgIndices.push_back(en.index());
|
||||
}
|
||||
|
@ -983,6 +1039,14 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
Type packedResult;
|
||||
unsigned numResults = callOp.getNumResults();
|
||||
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
|
||||
|
||||
for (Type resType : resultTypes) {
|
||||
assert(!resType.isa<UnrankedMemRefType>() &&
|
||||
"Returning unranked memref is not supported. Pass result as an"
|
||||
"argument instead.");
|
||||
(void)resType;
|
||||
}
|
||||
|
||||
if (numResults != 0) {
|
||||
if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
|
||||
return this->matchFailure();
|
||||
|
@ -1076,11 +1140,26 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
MemRefType sourceType =
|
||||
memRefCastOp.getOperand()->getType().cast<MemRefType>();
|
||||
MemRefType targetType = memRefCastOp.getType();
|
||||
return (isSupportedMemRefType(targetType) &&
|
||||
isSupportedMemRefType(sourceType))
|
||||
Type srcType = memRefCastOp.getOperand()->getType();
|
||||
Type dstType = memRefCastOp.getType();
|
||||
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
|
||||
MemRefType sourceType =
|
||||
memRefCastOp.getOperand()->getType().cast<MemRefType>();
|
||||
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
|
||||
return (isSupportedMemRefType(targetType) &&
|
||||
isSupportedMemRefType(sourceType))
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
}
|
||||
|
||||
// At least one of the operands is unranked type
|
||||
assert(srcType.isa<UnrankedMemRefType>() ||
|
||||
dstType.isa<UnrankedMemRefType>());
|
||||
|
||||
// Unranked to unranked cast is disallowed
|
||||
return !(srcType.isa<UnrankedMemRefType>() &&
|
||||
dstType.isa<UnrankedMemRefType>())
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
}
|
||||
|
@ -1089,12 +1168,65 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
OperandAdaptor<MemRefCastOp> transformed(operands);
|
||||
// memref_cast is defined for source and destination memref types with the
|
||||
// same element type, same mappings, same address space and same rank.
|
||||
// Therefore a simple bitcast suffices. If not it is undefined behavior.
|
||||
|
||||
auto srcType = memRefCastOp.getOperand()->getType();
|
||||
auto dstType = memRefCastOp.getType();
|
||||
auto targetStructType = lowering.convertType(memRefCastOp.getType());
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
|
||||
transformed.source());
|
||||
auto loc = op->getLoc();
|
||||
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
|
||||
// memref_cast is defined for source and destination memref types with the
|
||||
// same element type, same mappings, same address space and same rank.
|
||||
// Therefore a simple bitcast suffices. If not it is undefined behavior.
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
|
||||
transformed.source());
|
||||
} else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
|
||||
// Casting ranked to unranked memref type
|
||||
// Set the rank in the destination from the memref type
|
||||
// Allocate space on the stack and copy the src memref descriptor
|
||||
// Set the ptr in the destination to the stack space
|
||||
auto srcMemRefType = srcType.cast<MemRefType>();
|
||||
int64_t rank = srcMemRefType.getRank();
|
||||
// ptr = AllocaOp sizeof(MemRefDescriptor)
|
||||
auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
|
||||
rewriter);
|
||||
// voidptr = BitCastOp srcType* to void*
|
||||
auto voidPtr =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
|
||||
.getResult();
|
||||
// rank = ConstantOp srcRank
|
||||
auto rankVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, lowering.convertType(rewriter.getIntegerType(64)),
|
||||
rewriter.getI64IntegerAttr(rank));
|
||||
// undef = UndefOp
|
||||
UnrankedMemRefDescriptor memRefDesc =
|
||||
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
|
||||
// d1 = InsertValueOp undef, rank, 0
|
||||
memRefDesc.setRank(rewriter, loc, rankVal);
|
||||
// d2 = InsertValueOp d1, voidptr, 1
|
||||
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
|
||||
rewriter.replaceOp(op, (Value *)memRefDesc);
|
||||
|
||||
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
|
||||
// Casting from unranked type to ranked.
|
||||
// The operation is assumed to be doing a correct cast. If the destination
|
||||
// type mismatches the unranked type, it is undefined behavior.
|
||||
UnrankedMemRefDescriptor memRefDesc(transformed.source());
|
||||
// ptr = ExtractValueOp src, 1
|
||||
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
|
||||
// castPtr = BitCastOp i8* to structTy*
|
||||
auto castPtr =
|
||||
rewriter
|
||||
.create<LLVM::BitcastOp>(
|
||||
loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
|
||||
ptr)
|
||||
.getResult();
|
||||
// struct = LoadOp castPtr
|
||||
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
|
||||
rewriter.replaceOp(op, loadOp.getResult());
|
||||
} else {
|
||||
llvm_unreachable("Unsuppored unranked memref to unranked memref cast");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1896,7 +2028,8 @@ SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
|
|||
for (auto it : llvm::zip(opOperands, operands)) {
|
||||
auto *operand = std::get<0>(it);
|
||||
auto *llvmOperand = std::get<1>(it);
|
||||
if (!operand->getType().isa<MemRefType>()) {
|
||||
if (!operand->getType().isa<MemRefType>() &&
|
||||
!operand->getType().isa<UnrankedMemRefType>()) {
|
||||
promotedOperands.push_back(operand);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1769,46 +1769,70 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
|
|||
auto aT = a.dyn_cast<MemRefType>();
|
||||
auto bT = b.dyn_cast<MemRefType>();
|
||||
|
||||
if (!aT || !bT)
|
||||
return false;
|
||||
if (aT.getElementType() != bT.getElementType())
|
||||
return false;
|
||||
if (aT.getAffineMaps() != bT.getAffineMaps()) {
|
||||
int64_t aOffset, bOffset;
|
||||
SmallVector<int64_t, 4> aStrides, bStrides;
|
||||
if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
|
||||
failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
|
||||
aStrides.size() != bStrides.size())
|
||||
return false;
|
||||
auto uaT = a.dyn_cast<UnrankedMemRefType>();
|
||||
auto ubT = b.dyn_cast<UnrankedMemRefType>();
|
||||
|
||||
// Strides along a dimension/offset are compatible if the value in the
|
||||
// source memref is static and the value in the target memref is the
|
||||
// same. They are also compatible if either one is dynamic (see description
|
||||
// of MemRefCastOp for details).
|
||||
auto checkCompatible = [](int64_t a, int64_t b) {
|
||||
return (a == MemRefType::getDynamicStrideOrOffset() ||
|
||||
b == MemRefType::getDynamicStrideOrOffset() || a == b);
|
||||
};
|
||||
if (!checkCompatible(aOffset, bOffset))
|
||||
if (aT && bT) {
|
||||
if (aT.getElementType() != bT.getElementType())
|
||||
return false;
|
||||
for (auto aStride : enumerate(aStrides))
|
||||
if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
|
||||
if (aT.getAffineMaps() != bT.getAffineMaps()) {
|
||||
int64_t aOffset, bOffset;
|
||||
SmallVector<int64_t, 4> aStrides, bStrides;
|
||||
if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
|
||||
failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
|
||||
aStrides.size() != bStrides.size())
|
||||
return false;
|
||||
}
|
||||
if (aT.getMemorySpace() != bT.getMemorySpace())
|
||||
return false;
|
||||
|
||||
// They must have the same rank, and any specified dimensions must match.
|
||||
if (aT.getRank() != bT.getRank())
|
||||
return false;
|
||||
|
||||
for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
|
||||
int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
|
||||
if (aDim != -1 && bDim != -1 && aDim != bDim)
|
||||
// Strides along a dimension/offset are compatible if the value in the
|
||||
// source memref is static and the value in the target memref is the
|
||||
// same. They are also compatible if either one is dynamic (see
|
||||
// description of MemRefCastOp for details).
|
||||
auto checkCompatible = [](int64_t a, int64_t b) {
|
||||
return (a == MemRefType::getDynamicStrideOrOffset() ||
|
||||
b == MemRefType::getDynamicStrideOrOffset() || a == b);
|
||||
};
|
||||
if (!checkCompatible(aOffset, bOffset))
|
||||
return false;
|
||||
for (auto aStride : enumerate(aStrides))
|
||||
if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
|
||||
return false;
|
||||
}
|
||||
if (aT.getMemorySpace() != bT.getMemorySpace())
|
||||
return false;
|
||||
|
||||
// They must have the same rank, and any specified dimensions must match.
|
||||
if (aT.getRank() != bT.getRank())
|
||||
return false;
|
||||
|
||||
for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
|
||||
int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
|
||||
if (aDim != -1 && bDim != -1 && aDim != bDim)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
if (!aT && !uaT)
|
||||
return false;
|
||||
if (!bT && !ubT)
|
||||
return false;
|
||||
// Unranked to unranked casting is unsupported
|
||||
if (uaT && ubT)
|
||||
return false;
|
||||
|
||||
auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
|
||||
auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
|
||||
if (aEltType != bEltType)
|
||||
return false;
|
||||
|
||||
auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
|
||||
auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
|
||||
if (aMemSpace != bMemSpace)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -1086,6 +1086,13 @@ void ModulePrinter::printType(Type type) {
|
|||
os << '>';
|
||||
return;
|
||||
}
|
||||
case StandardTypes::UnrankedMemRef: {
|
||||
auto v = type.cast<UnrankedMemRefType>();
|
||||
os << "memref<*x";
|
||||
printType(v.getElementType());
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
case StandardTypes::Complex:
|
||||
os << "complex<";
|
||||
printType(type.cast<ComplexType>().getElementType());
|
||||
|
|
|
@ -90,8 +90,8 @@ struct BuiltinDialect : public Dialect {
|
|||
UnknownLoc>();
|
||||
|
||||
addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
|
||||
MemRefType, NoneType, OpaqueType, RankedTensorType, TupleType,
|
||||
UnrankedTensorType, VectorType>();
|
||||
MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
|
||||
RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
|
||||
|
||||
// TODO: These operations should be moved to a different dialect when they
|
||||
// have been fully decoupled from the core.
|
||||
|
|
|
@ -390,6 +390,37 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
|||
|
||||
unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnrankedMemRefType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Builds the unranked memref type with element type `elementType` living in
/// `memorySpace`. The type is requested from the context that `elementType`
/// belongs to.
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  MLIRContext *ctx = elementType.getContext();
  return Base::get(ctx, StandardTypes::UnrankedMemRef, elementType,
                   memorySpace);
}
|
||||
|
||||
/// Variant of `get` that goes through `Base::getChecked` and forwards
/// `location` to it (presumably so that invalid construction is diagnosed at
/// that location -- confirm against `Base::getChecked`). The type is requested
/// from the context that `elementType` belongs to.
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  MLIRContext *ctx = elementType.getContext();
  return Base::getChecked(location, ctx, StandardTypes::UnrankedMemRef,
                          elementType, memorySpace);
}
|
||||
|
||||
unsigned UnrankedMemRefType::getMemorySpace() const {
|
||||
return getImpl()->memorySpace;
|
||||
}
|
||||
|
||||
/// Verifies the invariants of an unranked memref type under construction.
///
/// The element type must be an integer, a float, or a vector type; otherwise
/// an "invalid memref element type" error is reported through `loc` and the
/// verification fails. `context` and `memorySpace` are not inspected.
LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, Type elementType,
    unsigned memorySpace) {
  // Check that memref is formed from allowed types.
  const bool isAllowedElementType =
      elementType.isIntOrFloat() || elementType.isa<VectorType>();
  if (!isAllowedElementType) {
    // `loc` is an Optional and may hold no value, so it must not be
    // dereferenced here. Hand it over as-is and let emitOptionalError decide
    // whether there is a location to attach the diagnostic to.
    return emitOptionalError(loc, "invalid memref element type");
  }
  return success();
}
|
||||
|
||||
/// Given MemRef `sizes` that are either static or dynamic, returns the
|
||||
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
|
||||
/// once a dynamic dimension is encountered, all canonical strides become
|
||||
|
|
|
@ -119,8 +119,8 @@ struct FunctionTypeStorage : public TypeStorage {
|
|||
|
||||
/// Shaped Type Storage.
|
||||
struct ShapedTypeStorage : public TypeStorage {
|
||||
ShapedTypeStorage(Type elementType, unsigned subclassData = 0)
|
||||
: TypeStorage(subclassData), elementType(elementType) {}
|
||||
ShapedTypeStorage(Type elementTy, unsigned subclassData = 0)
|
||||
: TypeStorage(subclassData), elementType(elementTy) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = Type;
|
||||
|
@ -252,6 +252,31 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
|
|||
const unsigned memorySpace;
|
||||
};
|
||||
|
||||
/// Unranked MemRef is a MemRef with unknown rank.
|
||||
/// Only element type and memory space are known
|
||||
struct UnrankedMemRefTypeStorage : public ShapedTypeStorage {
|
||||
|
||||
UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace)
|
||||
: ShapedTypeStorage(elementTy), memorySpace(memorySpace) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = std::tuple<Type, unsigned>;
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(elementType, memorySpace);
|
||||
}
|
||||
|
||||
/// Construction.
|
||||
static UnrankedMemRefTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
return new (allocator.allocate<UnrankedMemRefTypeStorage>())
|
||||
UnrankedMemRefTypeStorage(std::get<0>(key), std::get<1>(key));
|
||||
}
|
||||
/// Memory space in which data referenced by memref resides.
|
||||
const unsigned memorySpace;
|
||||
};
|
||||
|
||||
/// Complex Type Storage.
|
||||
struct ComplexTypeStorage : public TypeStorage {
|
||||
ComplexTypeStorage(Type elementType) : elementType(elementType) {}
|
||||
|
|
|
@ -1054,8 +1054,13 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
|
|||
|
||||
/// Parse a memref type.
|
||||
///
|
||||
/// memref-type ::= `memref` `<` dimension-list-ranked type
|
||||
/// (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
|
||||
/// memref-type ::= ranked-memref-type | unranked-memref-type
|
||||
///
|
||||
/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
|
||||
/// (`,` semi-affine-map-composition)? (`,`
|
||||
/// memory-space)? `>`
|
||||
///
|
||||
/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
|
||||
///
|
||||
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
|
||||
/// memory-space ::= integer-literal /* | TODO: address-space-id */
|
||||
|
@ -1066,9 +1071,20 @@ Type Parser::parseMemRefType() {
|
|||
if (parseToken(Token::less, "expected '<' in memref type"))
|
||||
return nullptr;
|
||||
|
||||
bool isUnranked;
|
||||
SmallVector<int64_t, 4> dimensions;
|
||||
if (parseDimensionListRanked(dimensions))
|
||||
return nullptr;
|
||||
|
||||
if (consumeIf(Token::star)) {
|
||||
// This is an unranked memref type.
|
||||
isUnranked = true;
|
||||
if (parseXInDimensionList())
|
||||
return nullptr;
|
||||
|
||||
} else {
|
||||
isUnranked = false;
|
||||
if (parseDimensionListRanked(dimensions))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Parse the element type.
|
||||
auto typeLoc = getToken().getLoc();
|
||||
|
@ -1093,6 +1109,8 @@ Type Parser::parseMemRefType() {
|
|||
consumeToken(Token::integer);
|
||||
parsedMemorySpace = true;
|
||||
} else {
|
||||
if (isUnranked)
|
||||
return emitError("cannot have affine map for unranked memref type");
|
||||
if (parsedMemorySpace)
|
||||
return emitError("expected memory space to be last in memref type");
|
||||
if (getToken().is(Token::kw_offset)) {
|
||||
|
@ -1131,6 +1149,10 @@ Type Parser::parseMemRefType() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (isUnranked)
|
||||
return UnrankedMemRefType::getChecked(elementType, memorySpace,
|
||||
getEncodedSourceLocation(typeLoc));
|
||||
|
||||
return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
|
||||
memorySpace, getEncodedSourceLocation(typeLoc));
|
||||
}
|
||||
|
|
|
@ -371,6 +371,30 @@ func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// The ranked descriptor is spilled to a stack slot, the slot pointer is
// type-erased to i8*, and it is packed together with the static rank into the
// unranked descriptor.
// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK: llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
// CHECK-DAG: llvm.insertvalue %[[p2]], %{{.*}}[1] : !llvm<"{ i64, i8* }">
  %0 = memref_cast %arg : memref<42x2x?xf32> to memref<*xf32>
  return
}
|
||||
|
||||
// CHECK-LABEL: func @memref_cast_unranked_to_ranked
|
||||
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
|
||||
// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
|
||||
// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
|
||||
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
|
||||
%0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
|
||||
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
|
||||
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
|
||||
|
|
|
@ -565,6 +565,11 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64
|
|||
// CHECK: {{%.*}} = memref_cast {{%.*}} : memref<64x16x4xf32, #[[BASE_MAP3]]> to memref<64x16x4xf32, #[[BASE_MAP0]]>
|
||||
%3 = memref_cast %2 : memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> to memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>
|
||||
|
||||
// CHECK: memref_cast %{{.*}} : memref<4xf32> to memref<*xf32>
|
||||
%4 = memref_cast %1 : memref<4xf32> to memref<*xf32>
|
||||
|
||||
// CHECK: memref_cast %{{.*}} : memref<*xf32> to memref<4xf32>
|
||||
%5 = memref_cast %4 : memref<*xf32> to memref<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -978,3 +978,34 @@ func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16,
|
|||
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// incompatible element types
|
||||
func @invalid_memref_cast() {
|
||||
%0 = alloc() : memref<2x5xf32, 0>
|
||||
// expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}}
|
||||
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// incompatible memory space
|
||||
func @invalid_memref_cast() {
|
||||
%0 = alloc() : memref<2x5xf32, 0>
|
||||
// expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32>' are cast incompatible}}
|
||||
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// unranked to unranked
|
||||
func @invalid_memref_cast() {
|
||||
%0 = alloc() : memref<2x5xf32, 0>
|
||||
%1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
|
||||
// expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
|
||||
%2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -78,6 +78,12 @@ template <typename T> struct StridedMemRefType<T, 0> {
|
|||
int64_t offset;
|
||||
};
|
||||
|
||||
// Unranked MemRef: runtime-side view of a memref whose rank is not known
// statically. It pairs the dynamic rank with a type-erased pointer to the
// ranked descriptor.
|
||||
struct UnrankedMemRefType {
|
||||
// Dynamic rank of the memref; print_memref_f32 uses it to select which
// StridedMemRefType<float, RANK> the descriptor pointer is cast to.
int64_t rank;
|
||||
// Type-erased pointer to the ranked memref descriptor.
void *descriptor;
|
||||
};
|
||||
|
||||
template <typename StreamType, typename T, int N>
|
||||
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
|
||||
static_assert(N > 0, "Expected N > 0");
|
||||
|
@ -97,6 +103,15 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
|
|||
<< " offset = " << V.offset;
|
||||
}
|
||||
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void
|
||||
print_memref_f32(UnrankedMemRefType *M);
|
||||
|
||||
template <typename StreamType>
|
||||
void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) {
|
||||
os << "Unranked Memref rank = " << V.rank << " "
|
||||
<< "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
|
||||
}
|
||||
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void
|
||||
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void
|
||||
|
|
|
@ -148,15 +148,41 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
|
|||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T> void printZeroDMemRef(StridedMemRefType<T, 0> &M) {
|
||||
template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
|
||||
std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
|
||||
<< " offset = " << M.offset << " data = [";
|
||||
MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
|
||||
extern "C" void
|
||||
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
|
||||
printMemRef(*M);
|
||||
}
|
||||
|
||||
/// Prints an unranked f32 memref to std::cout.
///
/// First prints the unranked metadata (rank and descriptor address), then
/// casts the type-erased descriptor to the ranked descriptor matching the
/// dynamic rank and prints it. Only ranks 0 through 4 are supported; any other
/// rank is reported on std::cerr and trips the assertion in debug builds.
extern "C" void print_memref_f32(UnrankedMemRefType *M) {
  printUnrankedMemRefMetaData(std::cout, *M);
  void *ptr = M->descriptor;

// Expands to a switch case that reinterprets `ptr` as a ranked descriptor of
// static rank RANK and prints it.
#define MEMREF_CASE(RANK)                                                      \
  case RANK:                                                                   \
    printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr)));        \
    break

  // Switch on the 64-bit rank directly: narrowing it to `int` first could make
  // an out-of-range rank alias one of the supported cases.
  switch (M->rank) {
    MEMREF_CASE(0);
    MEMREF_CASE(1);
    MEMREF_CASE(2);
    MEMREF_CASE(3);
    MEMREF_CASE(4);
  default:
    // `assert` compiles away under NDEBUG, so also report the problem at
    // runtime instead of silently printing nothing.
    std::cerr << "Unsupported rank to print: " << M->rank << std::endl;
    assert(0 && "Unsupported rank to print");
    break;
  }

// Keep the helper macro local to this function.
#undef MEMREF_CASE
}
|
||||
|
||||
extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
|
||||
printZeroDMemRef(*M);
|
||||
printMemRef(*M);
|
||||
}
|
||||
extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
|
||||
printMemRef(*M);
|
||||
|
@ -170,8 +196,3 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
|
|||
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
|
||||
printMemRef(*M);
|
||||
}
|
||||
|
||||
extern "C" void
|
||||
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
|
||||
printMemRef(*M);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
|
||||
|
||||
// CHECK: rank = 2
|
||||
// CHECK-SAME: sizes = [10, 3]
|
||||
// CHECK-SAME: strides = [3, 1]
|
||||
// CHECK-COUNT-10: [10, 10, 10]
|
||||
//
|
||||
// CHECK: rank = 2
|
||||
// CHECK-SAME: sizes = [10, 3]
|
||||
// CHECK-SAME: strides = [3, 1]
|
||||
// CHECK-COUNT-10: [5, 5, 5]
|
||||
//
|
||||
// CHECK: rank = 2
|
||||
// CHECK-SAME: sizes = [10, 3]
|
||||
// CHECK-SAME: strides = [3, 1]
|
||||
// CHECK-COUNT-10: [2, 2, 2]
|
||||
func @main() -> () {
|
||||
%A = alloc() : memref<10x3xf32, 0>
|
||||
%f2 = constant 2.00000e+00 : f32
|
||||
%f5 = constant 5.00000e+00 : f32
|
||||
%f10 = constant 10.00000e+00 : f32
|
||||
|
||||
%V = memref_cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
|
||||
linalg.fill(%V, %f10) : memref<?x?xf32, 0>, f32
|
||||
%U = memref_cast %A : memref<10x3xf32, 0> to memref<*xf32>
|
||||
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
|
||||
|
||||
%V2 = memref_cast %U : memref<*xf32> to memref<?x?xf32>
|
||||
linalg.fill(%V2, %f5) : memref<?x?xf32, 0>, f32
|
||||
%U2 = memref_cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
|
||||
call @print_memref_f32(%U2) : (memref<*xf32>) -> ()
|
||||
|
||||
%V3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
|
||||
%V4 = memref_cast %V3 : memref<*xf32> to memref<?x?xf32>
|
||||
linalg.fill(%V4, %f2) : memref<?x?xf32, 0>, f32
|
||||
%U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
|
||||
|
||||
dealloc %A : memref<10x3xf32, 0>
|
||||
return
|
||||
}
|
||||
|
||||
func @print_memref_f32(memref<*xf32>)
|
|
@ -7,7 +7,8 @@ func @print_0d() {
|
|||
%f = constant 2.00000e+00 : f32
|
||||
%A = alloc() : memref<f32>
|
||||
store %f, %A[]: memref<f32>
|
||||
call @print_memref_0d_f32(%A): (memref<f32>) -> ()
|
||||
%U = memref_cast %A : memref<f32> to memref<*xf32>
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
dealloc %A : memref<f32>
|
||||
return
|
||||
}
|
||||
|
@ -18,7 +19,8 @@ func @print_1d() {
|
|||
%A = alloc() : memref<16xf32>
|
||||
%B = memref_cast %A: memref<16xf32> to memref<?xf32>
|
||||
linalg.fill(%B, %f) : memref<?xf32>, f32
|
||||
call @print_memref_1d_f32(%B): (memref<?xf32>) -> ()
|
||||
%U = memref_cast %B : memref<?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
dealloc %A : memref<16xf32>
|
||||
return
|
||||
}
|
||||
|
@ -34,8 +36,8 @@ func @print_3d() {
|
|||
|
||||
%c2 = constant 2 : index
|
||||
store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32>
|
||||
|
||||
call @print_memref_3d_f32(%B): (memref<?x?x?xf32>) -> ()
|
||||
%U = memref_cast %B : memref<?x?x?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U): (memref<*xf32>) -> ()
|
||||
dealloc %A : memref<3x4x5xf32>
|
||||
return
|
||||
}
|
||||
|
@ -46,10 +48,7 @@ func @print_3d() {
|
|||
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
|
||||
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
|
||||
|
||||
func @print_memref_0d_f32(memref<f32>)
|
||||
func @print_memref_1d_f32(memref<?xf32>)
|
||||
func @print_memref_3d_f32(memref<?x?x?xf32>)
|
||||
|
||||
func @print_memref_f32(memref<*xf32>)
|
||||
|
||||
!vector_type_C = type vector<4x4xf32>
|
||||
!matrix_type_CC = type memref<1x1x!vector_type_C>
|
||||
|
|
|
@ -22,9 +22,10 @@ func @main() {
|
|||
store %sum, %kernel_dst[%tz, %ty, %tx] : memref<?x?x?xf32>
|
||||
gpu.return
|
||||
}
|
||||
call @print_memref_3d_f32(%dst) : (memref<?x?x?xf32>) -> ()
|
||||
%U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
|
||||
func @print_memref_3d_f32(%ptr : memref<?x?x?xf32>)
|
||||
func @print_memref_f32(%ptr : memref<*xf32>)
|
||||
|
|
|
@ -20,9 +20,10 @@ func @main() {
|
|||
store %res, %kernel_dst[%tx] : memref<?xf32>
|
||||
gpu.return
|
||||
}
|
||||
call @print_memref_1d_f32(%dst) : (memref<?xf32>) -> ()
|
||||
%U = memref_cast %dst : memref<?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%U) : (memref<*xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
|
||||
func @print_memref_1d_f32(memref<?xf32>)
|
||||
func @print_memref_f32(memref<*xf32>)
|
||||
|
|
Loading…
Reference in New Issue