forked from OSchip/llvm-project

Add UnrankedMemRef Type

Closes tensorflow/mlir#261
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/261 from nmostafa:nmostafa/unranked 96b6e918f6ed64496f7573b2db33c0b02658ca45
PiperOrigin-RevId: 284037040

parent e67acfa468, commit daff60cd68
@@ -90,6 +90,23 @@ memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, float*, i64, [5 x i64], [5

memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, <4 x float>*, i64, [1 x i64], [1 x i64] }">
```

If the rank of the memref is unknown at compile time, the memref is converted to
an unranked descriptor that contains: 1. a 64-bit integer representing the
dynamic rank of the memref, followed by 2. a pointer to a ranked memref
descriptor with the contents listed above.

Dynamically-ranked memrefs should be used only to pass arguments to external
library calls that expect a unified memref type. The called functions can parse
any unranked memref descriptor by reading the rank and parsing the enclosed
ranked descriptor pointer.

Examples:

```mlir {.mlir}
// unranked descriptor
memref<*xf32> -> !llvm.type<"{i64, i8*}">
```
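The same two layouts can be mirrored on the C++ side. The following is a minimal sketch, not part of this commit's API surface (the struct and function names are illustrative), of how a callee could read the rank and then reinterpret the enclosed ranked descriptor; it assumes an LP64 target where these structs carry no padding:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors !llvm.type<"{i64, i8*}">: the unranked descriptor.
struct UnrankedDescriptor {
  int64_t rank;     // dynamic rank of the memref
  void *descriptor; // pointer to the ranked descriptor below
};

// Mirrors the ranked descriptor "{ T*, T*, i64, [N x i64], [N x i64] }".
template <typename T, int N> struct RankedDescriptor {
  T *allocated;       // allocated pointer
  T *aligned;         // aligned pointer
  int64_t offset;     // offset into the aligned pointer
  int64_t sizes[N];   // one size per dimension
  int64_t strides[N]; // one stride per dimension
};

// Hypothetical library entry point taking a unified memref argument.
void print_shape_f32(UnrankedDescriptor *u) {
  // Dispatch on the dynamic rank, then reinterpret the inner descriptor.
  if (u->rank == 2) {
    auto *d = static_cast<RankedDescriptor<float, 2> *>(u->descriptor);
    std::printf("2-D memref: %lld x %lld\n", (long long)d->sizes[0],
                (long long)d->sizes[1]);
  }
  // ... other ranks handled analogously.
}
```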
### Function Types

Function types get converted to LLVM function types. The arguments are converted
@@ -912,12 +912,21 @@ Examples:

// Convert to a type with more known dimensions.
%4 = memref_cast %3 : memref<?x?xf32> to memref<4x?xf32>

// Convert to a type with unknown rank.
%5 = memref_cast %3 : memref<?x?xf32> to memref<*xf32>

// Convert to a type with static rank.
%6 = memref_cast %5 : memref<*xf32> to memref<?x?xf32>
```

Convert a memref from one type to an equivalent type without changing any data
elements. The types are equivalent if 1. they both have the same static rank,
same element type, same mappings, and same address space (the operation is
invalid if converting to a mismatching constant dimension), or 2. exactly one
of the operands has an unknown rank, and they both have the same element type
and same address space (the operation is invalid if both operands are of
dynamic rank or if converting to a mismatching static rank).

### 'mulf' operation
@@ -760,10 +760,16 @@ TODO: Need to decide on a representation for quantized integers

Syntax:

``` {.ebnf}
memref-type ::= ranked-memref-type | unranked-memref-type

ranked-memref-type ::= `memref` `<` dimension-list-ranked tensor-memref-element-type
                       (`,` layout-specification)? |
                       (`,` memory-space)? `>`

unranked-memref-type ::= `memref` `<*x` tensor-memref-element-type
                         (`,` memory-space)? `>`

stride-list ::= `[` (dimension (`,` dimension)*)? `]`
strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
layout-specification ::= semi-affine-map | strided-layout
@@ -774,9 +780,48 @@ A `memref` type is a reference to a region of memory (similar to a buffer

pointer, but more powerful). The buffer pointed to by a memref can be allocated,
aliased and deallocated. A memref can be used to read and write data from/to the
memory region which it references. Memref types use the same shape specifier as
tensor types. Note that `memref<f32>`, `memref<0 x f32>`, `memref<1 x 0 x f32>`,
and `memref<0 x 1 x f32>` are all different types.

A `memref` is allowed to have an unknown rank (e.g. `memref<*xf32>`). The
purpose of unranked memrefs is to allow external library functions to receive
memref arguments of any rank without versioning the functions based on the rank.
Other uses of this type are disallowed or will have undefined behavior.

##### Codegen of Unranked Memref

Using unranked memref in codegen besides the case mentioned above is highly
discouraged. Codegen is concerned with generating loop nests and specialized
instructions for high performance; unranked memref is concerned with hiding the
rank and thus the number of enclosing loops required to iterate over the data.
However, if there is a need to code-gen unranked memref, one possible path is to
cast into a static ranked type based on the dynamic rank. Another possible path
is to emit a single while loop conditioned on a linear index and perform
delinearization of the linear index to a dynamic array containing the (unranked)
indices. While this is possible, it is not expected to be a good idea to perform
this during codegen, as the cost of the translations is expected to be
prohibitive and optimizations at this level are not expected to be worthwhile.
If expressiveness is the main concern, irrespective of performance, passing
unranked memrefs to an external C++ library and implementing rank-agnostic logic
there is expected to be significantly simpler.

Unranked memrefs may provide expressiveness gains in the future and help bridge
the gap with unranked tensors. Unranked memrefs are not expected to be exposed
to codegen, but one may query the rank of an unranked memref (a special op will
be needed for this purpose) and perform a switch and cast to a ranked memref as
a prerequisite to codegen.

Example:

```mlir {.mlir}
// With static ranks, we need a function for each possible argument type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
call @helper_2D(%A) : (memref<16x32xf32>)->()
call @helper_3D(%B) : (memref<16x32x64xf32>)->()

// With unknown rank, the functions can be unified under one unranked type
%A = alloc() : memref<16x32xf32>
%B = alloc() : memref<16x32x64xf32>
// Remove rank info
%A_u = memref_cast %A : memref<16x32xf32> -> memref<*xf32>
%B_u = memref_cast %B : memref<16x32x64xf32> -> memref<*xf32>
// call same function with dynamic ranks
call @helper(%A_u) : (memref<*xf32>)->()
call @helper(%B_u) : (memref<*xf32>)->()
```
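To make the "rank-agnostic logic in an external C++ library" path concrete, here is a hedged sketch of what the unified `@helper` above could do on the library side. The descriptor layout follows the LLVM conversion document; the assumption of an LP64 target without struct padding is this sketch's, not the commit's, and all names are illustrative:

```cpp
#include <cstdint>

// Unranked argument as it arrives from MLIR: {i64 rank, i8* descriptor}.
struct UnrankedArg {
  int64_t rank;
  void *descriptor;
};

// The ranked descriptor begins with two pointers and an offset; the
// per-dimension sizes and strides follow. Reading the sizes only needs
// the rank, so one function serves memrefs of any rank without versioning.
int64_t numElementsF32(UnrankedArg *arg) {
  // Layout: { float* allocated; float* aligned; int64_t offset;
  //           int64_t sizes[rank]; int64_t strides[rank]; }
  char *base = static_cast<char *>(arg->descriptor);
  auto *sizes =
      reinterpret_cast<int64_t *>(base + 2 * sizeof(void *) + sizeof(int64_t));
  int64_t n = 1;
  for (int64_t i = 0; i < arg->rank; ++i)
    n *= sizes[i];
  return n;
}
```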
The core syntax and representation of a layout specification is a
[semi-affine map](Dialects/Affine.md#semi-affine-maps). Additionally, syntactic
@@ -34,6 +34,9 @@ class Type;

} // namespace llvm

namespace mlir {

class UnrankedMemRefType;

namespace LLVM {
class LLVMDialect;
class LLVMType;
@@ -116,6 +119,10 @@ private:

  // 2. as many index types as memref has dynamic dimensions.
  Type convertMemRefType(MemRefType type);

  // Convert an unranked memref type to an LLVM type that captures the
  // runtime rank and a pointer to the static ranked memref descriptor.
  Type convertUnrankedMemRefType(UnrankedMemRefType type);

  // Convert a 1D vector type into an LLVM vector type.
  Type convertVectorType(VectorType type);
@@ -127,10 +134,34 @@ private:

  LLVM::LLVMType unwrap(Type type);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct.
class StructBuilder {
public:
  /// Construct a helper for the given value.
  explicit StructBuilder(Value *v);
  /// Builds IR creating an `undef` value of the descriptor type.
  static StructBuilder undef(OpBuilder &builder, Location loc,
                             Type descriptorType);

  /*implicit*/ operator Value *() { return value; }

protected:
  // LLVM value
  Value *value;
  // Cached struct type.
  Type structType;

protected:
  /// Builds IR to extract a value from the struct at position pos
  Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR to set a value in the struct at position pos
  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit MemRefDescriptor(Value *descriptor);
@@ -169,22 +200,28 @@ public:

  /// Returns the (LLVM) type this descriptor points to.
  LLVM::LLVMType getElementType();

private:
  // Cached index type.
  Type indexType;
};

class UnrankedMemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit UnrankedMemRefDescriptor(Value *descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
                                        Type descriptorType);

  /// Builds IR extracting the rank from the descriptor
  Value *rank(OpBuilder &builder, Location loc);
  /// Builds IR setting the rank in the descriptor
  void setRank(OpBuilder &builder, Location loc, Value *value);
  /// Builds IR extracting ranked memref descriptor ptr
  Value *memRefDescPtr(OpBuilder &builder, Location loc);
  /// Builds IR setting ranked memref descriptor ptr
  void setMemRefDescPtr(OpBuilder &builder, Location loc, Value *value);
};
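For orientation, this is how the accessors above get used; the snippet is condensed from `MemRefCastOpLowering::rewrite` further down in this commit (the variables `rewriter`, `loc`, `targetStructType`, `rankVal`, and `voidPtr` are defined there), so it is a sketch of the intended usage rather than standalone code:

```cpp
// Build a {rank, void*} unranked descriptor for a ranked source memref.
UnrankedMemRefDescriptor memRefDesc =
    UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
memRefDesc.setRank(rewriter, loc, rankVal);          // insertvalue ..., [0]
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); // insertvalue ..., [1]
rewriter.replaceOp(op, (Value *)memRefDesc);
```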
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
/// purpose of type conversions.
@@ -842,7 +842,8 @@ def MemRefCastOp : CastOp<"memref_cast"> {

  let description = [{
    The "memref_cast" operation converts a memref from one type to an equivalent
    type with a compatible shape. The source and destination types are
    compatible if:
    a. both are ranked memref types with the same element type, affine mappings,
    address space, and rank but where the individual dimensions may add or
    remove constant dimensions from the memref type.
@@ -850,6 +851,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {

    acts as an assertion that fails at runtime if the dynamic dimensions
    disagree with the resultant destination size.

    Example:
    Assert that the input dynamic shape matches the destination static shape.
      %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
    Erase static shape information, replacing it with dynamic information.
@@ -864,10 +866,20 @@ def MemRefCastOp : CastOp<"memref_cast"> {

    dynamic information.
      %5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to
             memref<12x4xf32, offset:?, strides: [?, ?]>

    b. either or both memref types are unranked, with the same element type and
    address space.

    Example:
    Cast to concrete shape.
      %4 = memref_cast %1 : memref<*xf32> to memref<4x?xf32>

    Erase rank information.
      %5 = memref_cast %1 : memref<4x?xf32> to memref<*xf32>
  }];

  let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
  let results = (outs AnyRankedOrUnrankedMemRef);

  let extraClassDeclaration = [{
    /// Return true if `a` and `b` are valid operand and result pairs for
@@ -875,7 +887,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {

    static bool areCastCompatible(Type a, Type b);

    /// The result of a memref_cast is always a memref.
    Type getType() { return getResult()->getType(); }
  }];
}
@@ -221,6 +221,9 @@ def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;

// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;

// Whether a type is an UnrankedMemRefType.
def IsUnrankedMemRefTypePred : CPred<"$_self.isa<UnrankedMemRefType>()">;

// Whether a type is a ShapedType.
def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
@@ -486,6 +489,10 @@ class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;

class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;

// Unranked Memref type
def AnyUnrankedMemRef :
    ShapedContainerType<[AnyType],
                        IsUnrankedMemRefTypePred, "unranked.memref">;
// Memref type.

// Memrefs are blocks of data with fixed type and rank.
@@ -494,6 +501,8 @@ class MemRefOf<list<Type> allowedTypes> :

def AnyMemRef : MemRefOf<[AnyType]>;

def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;

// Memref declarations handle any memref, independent of rank, size, (static or
// dynamic), layout, or memory space.
def I1MemRef : MemRefOf<[I1]>;
@@ -40,6 +40,7 @@ struct VectorTypeStorage;

struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
struct UnrankedMemRefTypeStorage;
struct ComplexTypeStorage;
struct TupleTypeStorage;
@@ -64,6 +65,7 @@ enum Kind {

  RankedTensor,
  UnrankedTensor,
  MemRef,
  UnrankedMemRef,
  Complex,
  Tuple,
  None,
@@ -243,6 +245,7 @@ public:

    return type.getKind() == StandardTypes::Vector ||
           type.getKind() == StandardTypes::RankedTensor ||
           type.getKind() == StandardTypes::UnrankedTensor ||
           type.getKind() == StandardTypes::UnrankedMemRef ||
           type.getKind() == StandardTypes::MemRef;
  }
@@ -370,12 +373,24 @@ public:

  }
};

/// Base MemRef for Ranked and Unranked variants
class BaseMemRefType : public ShapedType {
public:
  using ShapedType::ShapedType;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type) {
    return type.getKind() == StandardTypes::MemRef ||
           type.getKind() == StandardTypes::UnrankedMemRef;
  }
};

/// MemRef types represent a region of memory that have a shape with a fixed
/// number of dimensions. Each shape element can be a non-negative integer or
/// unknown (represented by any negative integer). MemRef types also have an
/// affine map composition, represented as an array AffineMap pointers.
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
                                         detail::MemRefTypeStorage> {
public:
  using Base::Base;
@@ -426,6 +441,40 @@ private:

  using Base::getImpl;
};

/// Unranked MemRef types represent multi-dimensional MemRefs that
/// have an unknown rank.
class UnrankedMemRefType
    : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
                            detail::UnrankedMemRefTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new UnrankedMemRefType of the provided element
  /// type and memory space.
  static UnrankedMemRefType get(Type elementType, unsigned memorySpace);

  /// Get or create a new UnrankedMemRefType of the provided element
  /// type and memory space declared at the given, potentially unknown,
  /// location. If the UnrankedMemRefType defined by the arguments would be
  /// ill-formed, emit errors and return a nullptr-wrapping type.
  static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
                                       Location location);

  /// Verify the construction of an unranked memref type.
  static LogicalResult
  verifyConstructionInvariants(llvm::Optional<Location> loc,
                               MLIRContext *context, Type elementType,
                               unsigned memorySpace);

  ArrayRef<int64_t> getShape() const { return llvm::None; }

  /// Returns the memory space in which data referred to by this memref
  /// resides.
  unsigned getMemorySpace() const;
  static bool kindof(unsigned kind) {
    return kind == StandardTypes::UnrankedMemRef;
  }
};
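As a quick illustration of the API declared above, client code could build and inspect the new type roughly like this (a hedged sketch; `ctx` is assumed to be a valid `MLIRContext`, and the calls used are the ones declared in this commit or the pre-existing `ShapedType`/`FloatType` APIs):

```cpp
// Create f32 and memref<*xf32> in memory space 0, then sanity-check it.
Type f32 = FloatType::getF32(&ctx);
UnrankedMemRefType t = UnrankedMemRefType::get(f32, /*memorySpace=*/0);
assert(t.getElementType() == f32); // inherited from ShapedType
assert(t.isa<BaseMemRefType>());   // shared base with MemRefType
assert(t.getShape().empty());      // no shape: the rank is unknown
```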
/// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provide operations for manipulating them, e.g.
@@ -193,6 +193,22 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {

  return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
}

// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
// contains:
// 1. int64_t rank, the dynamic rank of this MemRef
// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
//    a stack allocated (alloca) copy of a MemRef descriptor that got casted to
//    be unranked.

static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;

Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
  auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
  auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
@@ -221,6 +237,8 @@ Type LLVMTypeConverter::convertStandardType(Type type) {

    return convertIndexType(indexType);
  if (auto memRefType = type.dyn_cast<MemRefType>())
    return convertMemRefType(memRefType);
  if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
    return convertUnrankedMemRefType(memRefType);
  if (auto vectorType = type.dyn_cast<VectorType>())
    return convertVectorType(vectorType);
  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
@@ -245,22 +263,42 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,

                               PatternBenefit benefit)
    : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}

/*============================================================================*/
/* StructBuilder implementation                                               */
/*============================================================================*/
StructBuilder::StructBuilder(Value *v) : value(v) {
  assert(value != nullptr && "value cannot be null");
  structType = value->getType().cast<LLVM::LLVMType>();
}

Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc,
                                 unsigned pos) {
  Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
  return builder.create<LLVM::ExtractValueOp>(loc, type, value,
                                              builder.getI64ArrayAttr(pos));
}

void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
                           Value *ptr) {
  value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
                                              builder.getI64ArrayAttr(pos));
}

/*============================================================================*/
/* MemRefDescriptor implementation                                            */
/*============================================================================*/

/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value *descriptor)
    : StructBuilder(descriptor) {
  assert(value != nullptr && "value cannot be null");
  indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
      kOffsetPosInMemRefDescriptor);
}

/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
                                         Type descriptorType) {
  Value *descriptor =
      builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
  return MemRefDescriptor(descriptor);
@@ -334,24 +372,42 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,

      builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

LLVM::LLVMType MemRefDescriptor::getElementType() {
  return value->getType().cast<LLVM::LLVMType>().getStructElementType(
      kAlignedPtrPosInMemRefDescriptor);
}

/*============================================================================*/
/* UnrankedMemRefDescriptor implementation                                    */
/*============================================================================*/

/// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor)
    : StructBuilder(descriptor) {}

/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
                                                         Location loc,
                                                         Type descriptorType) {
  Value *descriptor =
      builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
  return UnrankedMemRefDescriptor(descriptor);
}
Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
                                       Value *v) {
  setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
                                               Location loc) {
  return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
                                                Location loc, Value *v) {
  setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
@@ -432,7 +488,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {

      auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
      if (!converted)
        return matchFailure();
      if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) {
        converted = converted.getPointerTo();
        promotedArgIndices.push_back(en.index());
      }
@@ -983,6 +1039,14 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {

    Type packedResult;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    for (Type resType : resultTypes) {
      assert(!resType.isa<UnrankedMemRefType>() &&
             "Returning unranked memref is not supported. Pass result as an "
             "argument instead.");
      (void)resType;
    }

    if (numResults != 0) {
      if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
        return this->matchFailure();
@@ -1076,25 +1140,93 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {

  PatternMatchResult match(Operation *op) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    Type srcType = memRefCastOp.getOperand()->getType();
    Type dstType = memRefCastOp.getType();

    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
      MemRefType sourceType =
          memRefCastOp.getOperand()->getType().cast<MemRefType>();
      MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
      return (isSupportedMemRefType(targetType) &&
              isSupportedMemRefType(sourceType))
                 ? matchSuccess()
                 : matchFailure();
    }

    // At least one of the operands is of unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? matchSuccess()
               : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    OperandAdaptor<MemRefCastOp> transformed(operands);

    auto srcType = memRefCastOp.getOperand()->getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = lowering.convertType(memRefCastOp.getType());
    auto loc = op->getLoc();

    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
      // memref_cast is defined for source and destination memref types with the
      // same element type, same mappings, same address space and same rank.
      // Therefore a simple bitcast suffices. If not it is undefined behavior.
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
                                                   transformed.source());
    } else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // set the rank in the destination from the memref type,
      // allocate space on the stack and copy the src memref descriptor,
      // set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
                                                     rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, lowering.convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(op, (Value *)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
                  ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(op, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
@@ -1896,7 +2028,8 @@ SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(

  for (auto it : llvm::zip(opOperands, operands)) {
    auto *operand = std::get<0>(it);
    auto *llvmOperand = std::get<1>(it);
    if (!operand->getType().isa<MemRefType>() &&
        !operand->getType().isa<UnrankedMemRefType>()) {
      promotedOperands.push_back(operand);
      continue;
    }
@@ -1769,8 +1769,10 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {

  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
@@ -1783,8 +1785,8 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {

    // Strides along a dimension/offset are compatible if the value in the
    // source memref is static and the value in the target memref is the
    // same. They are also compatible if either one is dynamic (see
    // description of MemRefCastOp for details).
    auto checkCompatible = [](int64_t a, int64_t b) {
      return (a == MemRefType::getDynamicStrideOrOffset() ||
              b == MemRefType::getDynamicStrideOrOffset() || a == b);
@@ -1807,8 +1809,30 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {

      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
@@ -1086,6 +1086,13 @@ void ModulePrinter::printType(Type type) {

    os << '>';
    return;
  }
  case StandardTypes::UnrankedMemRef: {
    auto v = type.cast<UnrankedMemRefType>();
    os << "memref<*x";
    printType(v.getElementType());
    os << '>';
    return;
  }
  case StandardTypes::Complex:
    os << "complex<";
    printType(type.cast<ComplexType>().getElementType());
@@ -90,8 +90,8 @@ struct BuiltinDialect : public Dialect {

           UnknownLoc>();

  addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
           MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
           RankedTensorType, TupleType, UnrankedTensorType, VectorType>();

  // TODO: These operations should be moved to a different dialect when they
  // have been fully decoupled from the core.
@@ -390,6 +390,37 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {

unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
                   elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, elementType.getContext(),
                          StandardTypes::UnrankedMemRef, elementType,
                          memorySpace);
}

unsigned UnrankedMemRefType::getMemorySpace() const {
  return getImpl()->memorySpace;
}

LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, Type elementType,
    unsigned memorySpace) {
  // Check that memref is formed from allowed types.
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
    return emitOptionalError(*loc, "invalid memref element type");
  return success();
}

/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
@@ -119,8 +119,8 @@ struct FunctionTypeStorage : public TypeStorage {

/// Shaped Type Storage.
struct ShapedTypeStorage : public TypeStorage {
  ShapedTypeStorage(Type elementTy, unsigned subclassData = 0)
      : TypeStorage(subclassData), elementType(elementTy) {}

  /// The hash key used for uniquing.
  using KeyTy = Type;
@@ -252,6 +252,31 @@ struct MemRefTypeStorage : public ShapedTypeStorage {

  const unsigned memorySpace;
};

/// Unranked MemRef is a MemRef with unknown rank.
/// Only element type and memory space are known.
struct UnrankedMemRefTypeStorage : public ShapedTypeStorage {

  UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace)
      : ShapedTypeStorage(elementTy), memorySpace(memorySpace) {}

  /// The hash key used for uniquing.
  using KeyTy = std::tuple<Type, unsigned>;
  bool operator==(const KeyTy &key) const {
    return key == KeyTy(elementType, memorySpace);
  }

  /// Construction.
  static UnrankedMemRefTypeStorage *construct(TypeStorageAllocator &allocator,
                                              const KeyTy &key) {
    // Initialize the memory using placement new.
    return new (allocator.allocate<UnrankedMemRefTypeStorage>())
        UnrankedMemRefTypeStorage(std::get<0>(key), std::get<1>(key));
  }
  /// Memory space in which data referenced by memref resides.
  const unsigned memorySpace;
};

/// Complex Type Storage.
struct ComplexTypeStorage : public TypeStorage {
  ComplexTypeStorage(Type elementType) : elementType(elementType) {}
@@ -1054,8 +1054,13 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,

/// Parse a memref type.
///
///   memref-type ::= ranked-memref-type | unranked-memref-type
///
///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
///                          (`,` semi-affine-map-composition)? (`,`
///                          memory-space)? `>`
///
///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
///
///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
///   memory-space ::= integer-literal /* | TODO: address-space-id */
@@ -1066,9 +1071,20 @@ Type Parser::parseMemRefType() {

  if (parseToken(Token::less, "expected '<' in memref type"))
    return nullptr;

  bool isUnranked;
  SmallVector<int64_t, 4> dimensions;

  if (consumeIf(Token::star)) {
    // This is an unranked memref type.
    isUnranked = true;
    if (parseXInDimensionList())
      return nullptr;

  } else {
    isUnranked = false;
    if (parseDimensionListRanked(dimensions))
      return nullptr;
  }

  // Parse the element type.
  auto typeLoc = getToken().getLoc();
@@ -1093,6 +1109,8 @@ Type Parser::parseMemRefType() {

      consumeToken(Token::integer);
      parsedMemorySpace = true;
    } else {
      if (isUnranked)
        return emitError("cannot have affine map for unranked memref type");
      if (parsedMemorySpace)
        return emitError("expected memory space to be last in memref type");
      if (getToken().is(Token::kw_offset)) {
@@ -1131,6 +1149,10 @@ Type Parser::parseMemRefType() {

    return nullptr;
  }

  if (isUnranked)
    return UnrankedMemRefType::getChecked(elementType, memorySpace,
                                          getEncodedSourceLocation(typeLoc));

  return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
                                memorySpace, getEncodedSourceLocation(typeLoc));
}
@@ -371,6 +371,30 @@ func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {

  return
}

// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG:  %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG:  %[[p:.*]] = llvm.alloca %[[c]] x !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG:  llvm.store %[[ld]], %[[p]] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">
// CHECK-DAG:  %[[p2:.*]] = llvm.bitcast %2 : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> to !llvm<"i8*">
// CHECK-DAG:  %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK     : llvm.mlir.undef : !llvm<"{ i64, i8* }">
// CHECK-DAG:  llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm<"{ i64, i8* }">
// CHECK-DAG:  llvm.insertvalue %[[p2]], %{{.*}}[1] : !llvm<"{ i64, i8* }">
  %0 = memref_cast %arg : memref<42x2x?xf32> to memref<*xf32>
  return
}

// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
// CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*">
// CHECK-NEXT: %[[p:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ i64, i8* }">
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm<"i8*"> to !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }*">
  %0 = memref_cast %arg : memref<*xf32> to memref<?x?x10x2xf32>
  return
}

// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
@@ -565,6 +565,11 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64

  // CHECK: {{%.*}} = memref_cast {{%.*}} : memref<64x16x4xf32, #[[BASE_MAP3]]> to memref<64x16x4xf32, #[[BASE_MAP0]]>
  %3 = memref_cast %2 : memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> to memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>

  // CHECK: memref_cast %{{.*}} : memref<4xf32> to memref<*xf32>
  %4 = memref_cast %1 : memref<4xf32> to memref<*xf32>

  // CHECK: memref_cast %{{.*}} : memref<*xf32> to memref<4xf32>
  %5 = memref_cast %4 : memref<*xf32> to memref<4xf32>
  return
}
@@ -978,3 +978,34 @@ func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16,

  %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]>
  return
}

// -----

// incompatible element types
func @invalid_memref_cast() {
  %0 = alloc() : memref<2x5xf32, 0>
  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}}
  %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xi32>
  return
}

// -----

// incompatible memory space
func @invalid_memref_cast() {
  %0 = alloc() : memref<2x5xf32, 0>
  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32>' are cast incompatible}}
  %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1>
  return
}

// -----

// unranked to unranked
func @invalid_memref_cast() {
  %0 = alloc() : memref<2x5xf32, 0>
  %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
  %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
  return
}
@@ -78,6 +78,12 @@ template <typename T> struct StridedMemRefType<T, 0> {
  int64_t offset;
};

+// Unranked MemRef
+struct UnrankedMemRefType {
+  int64_t rank;
+  void *descriptor;
+};
+
template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
  static_assert(N > 0, "Expected N > 0");
@@ -97,6 +103,15 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
     << " offset = " << V.offset;
}

+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+print_memref_f32(UnrankedMemRefType *M);
+
+template <typename StreamType>
+void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) {
+  os << "Unranked Memref rank = " << V.rank << " "
+     << "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
+}
+
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
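A hypothetical host-side caller of the new unranked entry point, assuming the declarations above are in scope; the include path, data, and variable names are illustrative, not part of the patch:

```cpp
#include <cstdint>
#include "mlir_runner_utils.h" // exact path depends on the test setup

int main() {
  float data[6] = {1, 2, 3, 4, 5, 6};
  // A 2x3 row-major ranked descriptor, built by hand:
  // {allocated, aligned, offset, sizes, strides}.
  StridedMemRefType<float, 2> ranked{data, data, 0, {2, 3}, {3, 1}};
  // Wrap it in the type-erased unranked descriptor and hand it to the
  // unified printer declared above.
  UnrankedMemRefType unranked{2, &ranked};
  print_memref_f32(&unranked);
}
```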
@@ -148,15 +148,41 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
  std::cout << std::endl;
}

-template <typename T> void printZeroDMemRef(StridedMemRefType<T, 0> &M) {
+template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
  std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
            << " offset = " << M.offset << " data = [";
  MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
  std::cout << "]" << std::endl;
}

+extern "C" void
+print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
+  printMemRef(*M);
+}
+
+extern "C" void print_memref_f32(UnrankedMemRefType *M) {
+  printUnrankedMemRefMetaData(std::cout, *M);
+  int rank = M->rank;
+  void *ptr = M->descriptor;
+
+#define MEMREF_CASE(RANK)                                                    \
+  case RANK:                                                                 \
+    printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr)));      \
+    break
+
+  switch (rank) {
+    MEMREF_CASE(0);
+    MEMREF_CASE(1);
+    MEMREF_CASE(2);
+    MEMREF_CASE(3);
+    MEMREF_CASE(4);
+  default:
+    assert(0 && "Unsupported rank to print");
+  }
+}
+
extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
-  printZeroDMemRef(*M);
+  printMemRef(*M);
}
extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
  printMemRef(*M);
@@ -170,8 +196,3 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
  printMemRef(*M);
}
-
-extern "C" void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
-  printMemRef(*M);
-}
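To make the `MEMREF_CASE` expansion explicit, here is an equivalent macro-free sketch of the rank dispatch; it assumes the runner-utils declarations above, and the function names are illustrative:

```cpp
#include <cassert>

// Illustrative macro-free version of print_memref_f32's rank dispatch:
// recover the ranked descriptor type from the runtime rank, then print.
template <int RANK>
void printRankedFloat(void *ptr) {
  printMemRef(*static_cast<StridedMemRefType<float, RANK> *>(ptr));
}

void printUnrankedFloat(UnrankedMemRefType *M) {
  switch (M->rank) {
  case 0: printRankedFloat<0>(M->descriptor); break;
  case 1: printRankedFloat<1>(M->descriptor); break;
  case 2: printRankedFloat<2>(M->descriptor); break;
  case 3: printRankedFloat<3>(M->descriptor); break;
  case 4: printRankedFloat<4>(M->descriptor); break;
  default: assert(0 && "Unsupported rank to print");
  }
}
```

Either way, the dispatch caps printable ranks at 4; anything higher trips the assert.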
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [10, 10, 10]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [5, 5, 5]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [2, 2, 2]
+func @main() -> () {
+  %A = alloc() : memref<10x3xf32, 0>
+  %f2 = constant 2.00000e+00 : f32
+  %f5 = constant 5.00000e+00 : f32
+  %f10 = constant 10.00000e+00 : f32
+
+  %V = memref_cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
+  linalg.fill(%V, %f10) : memref<?x?xf32, 0>, f32
+  %U = memref_cast %A : memref<10x3xf32, 0> to memref<*xf32>
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+
+  %V2 = memref_cast %U : memref<*xf32> to memref<?x?xf32>
+  linalg.fill(%V2, %f5) : memref<?x?xf32, 0>, f32
+  %U2 = memref_cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
+  call @print_memref_f32(%U2) : (memref<*xf32>) -> ()
+
+  %V3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
+  %V4 = memref_cast %V3 : memref<*xf32> to memref<?x?xf32>
+  linalg.fill(%V4, %f2) : memref<?x?xf32, 0>, f32
+  %U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
+
+  dealloc %A : memref<10x3xf32, 0>
+  return
+}
+
+func @print_memref_f32(memref<*xf32>)
@@ -7,7 +7,8 @@ func @print_0d() {
  %f = constant 2.00000e+00 : f32
  %A = alloc() : memref<f32>
  store %f, %A[]: memref<f32>
-  call @print_memref_0d_f32(%A): (memref<f32>) -> ()
+  %U = memref_cast %A : memref<f32> to memref<*xf32>
+  call @print_memref_f32(%U): (memref<*xf32>) -> ()
  dealloc %A : memref<f32>
  return
}
@@ -18,7 +19,8 @@ func @print_1d() {
  %A = alloc() : memref<16xf32>
  %B = memref_cast %A: memref<16xf32> to memref<?xf32>
  linalg.fill(%B, %f) : memref<?xf32>, f32
-  call @print_memref_1d_f32(%B): (memref<?xf32>) -> ()
+  %U = memref_cast %B : memref<?xf32> to memref<*xf32>
+  call @print_memref_f32(%U): (memref<*xf32>) -> ()
  dealloc %A : memref<16xf32>
  return
}
@@ -34,8 +36,8 @@ func @print_3d() {
  %c2 = constant 2 : index
  store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32>
-  call @print_memref_3d_f32(%B): (memref<?x?x?xf32>) -> ()
+  %U = memref_cast %B : memref<?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%U): (memref<*xf32>) -> ()
  dealloc %A : memref<3x4x5xf32>
  return
}
@@ -46,10 +48,7 @@ func @print_3d() {
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
// PRINT-3D-NEXT: 2, 2, 2, 2, 2

-func @print_memref_0d_f32(memref<f32>)
+func @print_memref_f32(memref<*xf32>)
-func @print_memref_1d_f32(memref<?xf32>)
-func @print_memref_3d_f32(memref<?x?x?xf32>)
!vector_type_C = type vector<4x4xf32>
!matrix_type_CC = type memref<1x1x!vector_type_C>
@@ -22,9 +22,10 @@ func @main() {
    store %sum, %kernel_dst[%tz, %ty, %tx] : memref<?x?x?xf32>
    gpu.return
  }
-  call @print_memref_3d_f32(%dst) : (memref<?x?x?xf32>) -> ()
+  %U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
  return
}

func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
-func @print_memref_3d_f32(%ptr : memref<?x?x?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)
@@ -20,9 +20,10 @@ func @main() {
    store %res, %kernel_dst[%tx] : memref<?xf32>
    gpu.return
  }
-  call @print_memref_1d_f32(%dst) : (memref<?xf32>) -> ()
+  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
  return
}

func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @print_memref_1d_f32(memref<?xf32>)
+func @print_memref_f32(memref<*xf32>)