forked from OSchip/llvm-project
[mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS
These are among the last operations still defined explicitly in C++. I've tried to keep this commit as NFC as possible, but these ops definitely need a non-NFC cleanup at some point. Differential Revision: https://reviews.llvm.org/D110440
This commit is contained in:
parent
96cb97c453
commit
aca9bea199
|
@ -46,206 +46,4 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
|
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace memref {
|
|
||||||
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
|
|
||||||
// source memref to a destination memref. The source and destination memref need
|
|
||||||
// not be of the same dimensionality, but need to have the same elemental type.
|
|
||||||
// The operands include the source and destination memref's each followed by its
|
|
||||||
// indices, size of the data transfer in terms of the number of elements (of the
|
|
||||||
// elemental type of the memref), a tag memref with its indices, and optionally
|
|
||||||
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
|
|
||||||
// location is used by a DmaWaitOp to check for completion. The indices of the
|
|
||||||
// source memref, destination memref, and the tag memref have the same
|
|
||||||
// restrictions as any load/store. The optional stride arguments should be of
|
|
||||||
// 'index' type, and specify a stride for the slower memory space (memory space
|
|
||||||
// with a lower memory space id), transferring chunks of
|
|
||||||
// number_of_elements_per_stride every stride until %num_elements are
|
|
||||||
// transferred. Either both or no stride arguments should be specified. If the
|
|
||||||
// source and destination locations overlap the behavior of this operation is
|
|
||||||
// not defined.
|
|
||||||
//
|
|
||||||
// For example, a DmaStartOp operation that transfers 256 elements of a memref
|
|
||||||
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
|
|
||||||
// 1 at indices [%k, %l], would be specified as follows:
|
|
||||||
//
|
|
||||||
// %num_elements = constant 256
|
|
||||||
// %idx = constant 0 : index
|
|
||||||
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
|
|
||||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
|
|
||||||
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
|
|
||||||
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
|
|
||||||
// memref<1 x i32>, (d0) -> (d0), 2>
|
|
||||||
//
|
|
||||||
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
|
||||||
// transfer %num_elt_per_stride elements every %stride elements apart from
|
|
||||||
// memory space 0 until %num_elements are transferred.
|
|
||||||
//
|
|
||||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
|
|
||||||
// %num_elt_per_stride :
|
|
||||||
//
|
|
||||||
// TODO: add additional operands to allow source and destination striding, and
|
|
||||||
// multiple stride levels.
|
|
||||||
// TODO: Consider replacing src/dst memref indices with view memrefs.
|
|
||||||
class DmaStartOp
|
|
||||||
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
|
||||||
public:
|
|
||||||
using Op::Op;
|
|
||||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
||||||
|
|
||||||
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
|
|
||||||
ValueRange srcIndices, Value destMemRef,
|
|
||||||
ValueRange destIndices, Value numElements, Value tagMemRef,
|
|
||||||
ValueRange tagIndices, Value stride = nullptr,
|
|
||||||
Value elementsPerStride = nullptr);
|
|
||||||
|
|
||||||
// Returns the source MemRefType for this DMA operation.
|
|
||||||
Value getSrcMemRef() { return getOperand(0); }
|
|
||||||
// Returns the rank (number of indices) of the source MemRefType.
|
|
||||||
unsigned getSrcMemRefRank() {
|
|
||||||
return getSrcMemRef().getType().cast<MemRefType>().getRank();
|
|
||||||
}
|
|
||||||
// Returns the source memref indices for this DMA operation.
|
|
||||||
operand_range getSrcIndices() {
|
|
||||||
return {(*this)->operand_begin() + 1,
|
|
||||||
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the destination MemRefType for this DMA operations.
|
|
||||||
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
|
|
||||||
// Returns the rank (number of indices) of the destination MemRefType.
|
|
||||||
unsigned getDstMemRefRank() {
|
|
||||||
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
|
||||||
}
|
|
||||||
unsigned getSrcMemorySpace() {
|
|
||||||
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
|
||||||
}
|
|
||||||
unsigned getDstMemorySpace() {
|
|
||||||
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the destination memref indices for this DMA operation.
|
|
||||||
operand_range getDstIndices() {
|
|
||||||
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
|
|
||||||
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
|
|
||||||
getDstMemRefRank()};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the number of elements being transferred by this DMA operation.
|
|
||||||
Value getNumElements() {
|
|
||||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the Tag MemRef for this DMA operation.
|
|
||||||
Value getTagMemRef() {
|
|
||||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
|
|
||||||
}
|
|
||||||
// Returns the rank (number of indices) of the tag MemRefType.
|
|
||||||
unsigned getTagMemRefRank() {
|
|
||||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the tag memref index for this DMA operation.
|
|
||||||
operand_range getTagIndices() {
|
|
||||||
unsigned tagIndexStartPos =
|
|
||||||
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
|
|
||||||
return {(*this)->operand_begin() + tagIndexStartPos,
|
|
||||||
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if this is a DMA from a faster memory space to a slower one.
|
|
||||||
bool isDestMemorySpaceFaster() {
|
|
||||||
return (getSrcMemorySpace() < getDstMemorySpace());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if this is a DMA from a slower memory space to a faster one.
|
|
||||||
bool isSrcMemorySpaceFaster() {
|
|
||||||
// Assumes that a lower number is for a slower memory space.
|
|
||||||
return (getDstMemorySpace() < getSrcMemorySpace());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Given a DMA start operation, returns the operand position of either the
|
|
||||||
/// source or destination memref depending on the one that is at the higher
|
|
||||||
/// level of the memory hierarchy. Asserts failure if neither is true.
|
|
||||||
unsigned getFasterMemPos() {
|
|
||||||
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
|
|
||||||
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
static StringRef getOperationName() { return "memref.dma_start"; }
|
|
||||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
|
||||||
void print(OpAsmPrinter &p);
|
|
||||||
LogicalResult verify();
|
|
||||||
|
|
||||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
|
||||||
SmallVectorImpl<OpFoldResult> &results);
|
|
||||||
|
|
||||||
bool isStrided() {
|
|
||||||
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
|
|
||||||
1 + 1 + getTagMemRefRank();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value getStride() {
|
|
||||||
if (!isStrided())
|
|
||||||
return nullptr;
|
|
||||||
return getOperand(getNumOperands() - 1 - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value getNumElementsPerStride() {
|
|
||||||
if (!isStrided())
|
|
||||||
return nullptr;
|
|
||||||
return getOperand(getNumOperands() - 1);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// DmaWaitOp blocks until the completion of a DMA operation associated with the
|
|
||||||
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
|
|
||||||
// with the same restrictions as any load/store index. %num_elements is the
|
|
||||||
// number of elements associated with the DMA operation. For example:
|
|
||||||
//
|
|
||||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
|
|
||||||
// memref<2048 x f32>, (d0) -> (d0), 0>,
|
|
||||||
// memref<256 x f32>, (d0) -> (d0), 1>
|
|
||||||
// memref<1 x i32>, (d0) -> (d0), 2>
|
|
||||||
// ...
|
|
||||||
// ...
|
|
||||||
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
|
|
||||||
//
|
|
||||||
class DmaWaitOp
|
|
||||||
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
|
||||||
public:
|
|
||||||
using Op::Op;
|
|
||||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
||||||
|
|
||||||
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
|
|
||||||
ValueRange tagIndices, Value numElements);
|
|
||||||
|
|
||||||
static StringRef getOperationName() { return "memref.dma_wait"; }
|
|
||||||
|
|
||||||
// Returns the Tag MemRef associated with the DMA operation being waited on.
|
|
||||||
Value getTagMemRef() { return getOperand(0); }
|
|
||||||
|
|
||||||
// Returns the tag memref index for this DMA operation.
|
|
||||||
operand_range getTagIndices() {
|
|
||||||
return {(*this)->operand_begin() + 1,
|
|
||||||
(*this)->operand_begin() + 1 + getTagMemRefRank()};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the rank (number of indices) of the tag memref.
|
|
||||||
unsigned getTagMemRefRank() {
|
|
||||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the number of elements transferred in the associated DMA operation.
|
|
||||||
Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
|
|
||||||
|
|
||||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
|
||||||
void print(OpAsmPrinter &p);
|
|
||||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
|
||||||
SmallVectorImpl<OpFoldResult> &results);
|
|
||||||
LogicalResult verify();
|
|
||||||
};
|
|
||||||
} // namespace memref
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_
|
#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_
|
||||||
|
|
|
@ -284,8 +284,6 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
|
||||||
let verifier = ?;
|
let verifier = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BufferCastOp
|
// BufferCastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -568,6 +566,217 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DmaStartOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
|
||||||
|
let summary = "non-blocking DMA operation that starts a transfer";
|
||||||
|
let description = [{
|
||||||
|
DmaStartOp starts a non-blocking DMA operation that transfers data from a
|
||||||
|
source memref to a destination memref. The source and destination memref
|
||||||
|
need not be of the same dimensionality, but need to have the same elemental
|
||||||
|
type. The operands include the source and destination memref's each followed
|
||||||
|
by its indices, size of the data transfer in terms of the number of elements
|
||||||
|
(of the elemental type of the memref), a tag memref with its indices, and
|
||||||
|
optionally at the end, a stride and a number_of_elements_per_stride
|
||||||
|
arguments. The tag location is used by a DmaWaitOp to check for completion.
|
||||||
|
The indices of the source memref, destination memref, and the tag memref
|
||||||
|
have the same restrictions as any load/store. The optional stride arguments
|
||||||
|
should be of 'index' type, and specify a stride for the slower memory space
|
||||||
|
(memory space with a lower memory space id), transferring chunks of
|
||||||
|
number_of_elements_per_stride every stride until %num_elements are
|
||||||
|
transferred. Either both or no stride arguments should be specified. If the
|
||||||
|
source and destination locations overlap the behavior of this operation is
|
||||||
|
not defined.
|
||||||
|
|
||||||
|
For example, a DmaStartOp operation that transfers 256 elements of a memref
|
||||||
|
'%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
|
||||||
|
space 1 at indices [%k, %l], would be specified as follows:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%num_elements = constant 256
|
||||||
|
%idx = constant 0 : index
|
||||||
|
%tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
|
||||||
|
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
|
||||||
|
memref<40 x 128 x f32>, (d0) -> (d0), 0>,
|
||||||
|
memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
|
||||||
|
memref<1 x i32>, (d0) -> (d0), 2>
|
||||||
|
```
|
||||||
|
|
||||||
|
If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
||||||
|
transfer %num_elt_per_stride elements every %stride elements apart from
|
||||||
|
memory space 0 until %num_elements are transferred.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
|
||||||
|
%num_elt_per_stride :
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO: add additional operands to allow source and destination striding, and
|
||||||
|
multiple stride levels.
|
||||||
|
TODO: Consider replacing src/dst memref indices with view memrefs.
|
||||||
|
}];
|
||||||
|
let arguments = (ins Variadic<AnyType>:$operands);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$srcMemRef, "ValueRange":$srcIndices,
|
||||||
|
"Value":$destMemRef, "ValueRange":$destIndices,
|
||||||
|
"Value":$numElements, "Value":$tagMemRef,
|
||||||
|
"ValueRange":$tagIndices, CArg<"Value", "{}">:$stride,
|
||||||
|
CArg<"Value", "{}">:$elementsPerStride)>
|
||||||
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// Returns the source MemRefType for this DMA operation.
|
||||||
|
Value getSrcMemRef() { return getOperand(0); }
|
||||||
|
// Returns the rank (number of indices) of the source MemRefType.
|
||||||
|
unsigned getSrcMemRefRank() {
|
||||||
|
return getSrcMemRef().getType().cast<MemRefType>().getRank();
|
||||||
|
}
|
||||||
|
// Returns the source memref indices for this DMA operation.
|
||||||
|
operand_range getSrcIndices() {
|
||||||
|
return {(*this)->operand_begin() + 1,
|
||||||
|
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the destination MemRefType for this DMA operations.
|
||||||
|
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
|
||||||
|
// Returns the rank (number of indices) of the destination MemRefType.
|
||||||
|
unsigned getDstMemRefRank() {
|
||||||
|
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
||||||
|
}
|
||||||
|
unsigned getSrcMemorySpace() {
|
||||||
|
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||||
|
}
|
||||||
|
unsigned getDstMemorySpace() {
|
||||||
|
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the destination memref indices for this DMA operation.
|
||||||
|
operand_range getDstIndices() {
|
||||||
|
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
|
||||||
|
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
|
||||||
|
getDstMemRefRank()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of elements being transferred by this DMA operation.
|
||||||
|
Value getNumElements() {
|
||||||
|
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the Tag MemRef for this DMA operation.
|
||||||
|
Value getTagMemRef() {
|
||||||
|
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
|
||||||
|
}
|
||||||
|
// Returns the rank (number of indices) of the tag MemRefType.
|
||||||
|
unsigned getTagMemRefRank() {
|
||||||
|
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the tag memref index for this DMA operation.
|
||||||
|
operand_range getTagIndices() {
|
||||||
|
unsigned tagIndexStartPos =
|
||||||
|
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
|
||||||
|
return {(*this)->operand_begin() + tagIndexStartPos,
|
||||||
|
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this is a DMA from a faster memory space to a slower
|
||||||
|
/// one.
|
||||||
|
bool isDestMemorySpaceFaster() {
|
||||||
|
return (getSrcMemorySpace() < getDstMemorySpace());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this is a DMA from a slower memory space to a faster
|
||||||
|
/// one.
|
||||||
|
bool isSrcMemorySpaceFaster() {
|
||||||
|
// Assumes that a lower number is for a slower memory space.
|
||||||
|
return (getDstMemorySpace() < getSrcMemorySpace());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given a DMA start operation, returns the operand position of either the
|
||||||
|
/// source or destination memref depending on the one that is at the higher
|
||||||
|
/// level of the memory hierarchy. Asserts failure if neither is true.
|
||||||
|
unsigned getFasterMemPos() {
|
||||||
|
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
|
||||||
|
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isStrided() {
|
||||||
|
return getNumOperands() != 1 + getSrcMemRefRank() + 1 +
|
||||||
|
getDstMemRefRank() + 1 + 1 +
|
||||||
|
getTagMemRefRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getStride() {
|
||||||
|
if (!isStrided())
|
||||||
|
return nullptr;
|
||||||
|
return getOperand(getNumOperands() - 1 - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getNumElementsPerStride() {
|
||||||
|
if (!isStrided())
|
||||||
|
return nullptr;
|
||||||
|
return getOperand(getNumOperands() - 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DmaWaitOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
|
||||||
|
let summary = "blocking DMA operation that waits for transfer completion";
|
||||||
|
let description = [{
|
||||||
|
DmaWaitOp blocks until the completion of a DMA operation associated with the
|
||||||
|
tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
|
||||||
|
with the same restrictions as any load/store index. %num_elements is the
|
||||||
|
number of elements associated with the DMA operation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
|
||||||
|
memref<2048 x f32>, (d0) -> (d0), 0>,
|
||||||
|
memref<256 x f32>, (d0) -> (d0), 1>
|
||||||
|
memref<1 x i32>, (d0) -> (d0), 2>
|
||||||
|
...
|
||||||
|
...
|
||||||
|
dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
AnyMemRef:$tagMemRef,
|
||||||
|
Variadic<Index>:$tagIndices,
|
||||||
|
Index:$numElements
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$tagMemRef `[` $tagIndices `]` `,` $numElements attr-dict `:`
|
||||||
|
type($tagMemRef)
|
||||||
|
}];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
/// Returns the Tag MemRef associated with the DMA operation being waited
|
||||||
|
/// on.
|
||||||
|
Value getTagMemRef() { return tagMemRef(); }
|
||||||
|
|
||||||
|
/// Returns the tag memref index for this DMA operation.
|
||||||
|
operand_range getTagIndices() { return tagIndices(); }
|
||||||
|
|
||||||
|
/// Returns the rank (number of indices) of the tag memref.
|
||||||
|
unsigned getTagMemRefRank() {
|
||||||
|
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of elements transferred in the associated DMA
|
||||||
|
/// operation.
|
||||||
|
Value getNumElements() { return numElements(); }
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GetGlobalOp
|
// GetGlobalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -34,9 +34,9 @@ struct MemRefInlinerInterface : public DialectInlinerInterface {
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
void mlir::memref::MemRefDialect::initialize() {
|
void mlir::memref::MemRefDialect::initialize() {
|
||||||
addOperations<DmaStartOp, DmaWaitOp,
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
|
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addInterfaces<MemRefInlinerInterface>();
|
addInterfaces<MemRefInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -909,16 +909,17 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result,
|
||||||
result.addOperands({stride, elementsPerStride});
|
result.addOperands({stride, elementsPerStride});
|
||||||
}
|
}
|
||||||
|
|
||||||
void DmaStartOp::print(OpAsmPrinter &p) {
|
static void print(OpAsmPrinter &p, DmaStartOp op) {
|
||||||
p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
|
p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
|
||||||
<< getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
|
<< op.getDstMemRef() << '[' << op.getDstIndices() << "], "
|
||||||
<< ", " << getTagMemRef() << '[' << getTagIndices() << ']';
|
<< op.getNumElements() << ", " << op.getTagMemRef() << '['
|
||||||
if (isStrided())
|
<< op.getTagIndices() << ']';
|
||||||
p << ", " << getStride() << ", " << getNumElementsPerStride();
|
if (op.isStrided())
|
||||||
|
p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
|
||||||
|
|
||||||
p.printOptionalAttrDict((*this)->getAttrs());
|
p.printOptionalAttrDict(op->getAttrs());
|
||||||
p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
|
p << " : " << op.getSrcMemRef().getType() << ", "
|
||||||
<< ", " << getTagMemRef().getType();
|
<< op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse DmaStartOp.
|
// Parse DmaStartOp.
|
||||||
|
@ -929,7 +930,8 @@ void DmaStartOp::print(OpAsmPrinter &p) {
|
||||||
// memref<1024 x f32, 2>,
|
// memref<1024 x f32, 2>,
|
||||||
// memref<1 x i32>
|
// memref<1 x i32>
|
||||||
//
|
//
|
||||||
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseDmaStartOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
OpAsmParser::OperandType srcMemRefInfo;
|
OpAsmParser::OperandType srcMemRefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
|
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
|
||||||
OpAsmParser::OperandType dstMemRefInfo;
|
OpAsmParser::OperandType dstMemRefInfo;
|
||||||
|
@ -989,66 +991,67 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult DmaStartOp::verify() {
|
static LogicalResult verify(DmaStartOp op) {
|
||||||
unsigned numOperands = getNumOperands();
|
unsigned numOperands = op.getNumOperands();
|
||||||
|
|
||||||
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
|
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
|
||||||
// the number of elements.
|
// the number of elements.
|
||||||
if (numOperands < 4)
|
if (numOperands < 4)
|
||||||
return emitOpError("expected at least 4 operands");
|
return op.emitOpError("expected at least 4 operands");
|
||||||
|
|
||||||
// Check types of operands. The order of these calls is important: the later
|
// Check types of operands. The order of these calls is important: the later
|
||||||
// calls rely on some type properties to compute the operand position.
|
// calls rely on some type properties to compute the operand position.
|
||||||
// 1. Source memref.
|
// 1. Source memref.
|
||||||
if (!getSrcMemRef().getType().isa<MemRefType>())
|
if (!op.getSrcMemRef().getType().isa<MemRefType>())
|
||||||
return emitOpError("expected source to be of memref type");
|
return op.emitOpError("expected source to be of memref type");
|
||||||
if (numOperands < getSrcMemRefRank() + 4)
|
if (numOperands < op.getSrcMemRefRank() + 4)
|
||||||
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
|
return op.emitOpError()
|
||||||
<< " operands";
|
<< "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
|
||||||
if (!getSrcIndices().empty() &&
|
if (!op.getSrcIndices().empty() &&
|
||||||
!llvm::all_of(getSrcIndices().getTypes(),
|
!llvm::all_of(op.getSrcIndices().getTypes(),
|
||||||
[](Type t) { return t.isIndex(); }))
|
[](Type t) { return t.isIndex(); }))
|
||||||
return emitOpError("expected source indices to be of index type");
|
return op.emitOpError("expected source indices to be of index type");
|
||||||
|
|
||||||
// 2. Destination memref.
|
// 2. Destination memref.
|
||||||
if (!getDstMemRef().getType().isa<MemRefType>())
|
if (!op.getDstMemRef().getType().isa<MemRefType>())
|
||||||
return emitOpError("expected destination to be of memref type");
|
return op.emitOpError("expected destination to be of memref type");
|
||||||
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
|
unsigned numExpectedOperands =
|
||||||
|
op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
|
||||||
if (numOperands < numExpectedOperands)
|
if (numOperands < numExpectedOperands)
|
||||||
return emitOpError() << "expected at least " << numExpectedOperands
|
return op.emitOpError()
|
||||||
<< " operands";
|
<< "expected at least " << numExpectedOperands << " operands";
|
||||||
if (!getDstIndices().empty() &&
|
if (!op.getDstIndices().empty() &&
|
||||||
!llvm::all_of(getDstIndices().getTypes(),
|
!llvm::all_of(op.getDstIndices().getTypes(),
|
||||||
[](Type t) { return t.isIndex(); }))
|
[](Type t) { return t.isIndex(); }))
|
||||||
return emitOpError("expected destination indices to be of index type");
|
return op.emitOpError("expected destination indices to be of index type");
|
||||||
|
|
||||||
// 3. Number of elements.
|
// 3. Number of elements.
|
||||||
if (!getNumElements().getType().isIndex())
|
if (!op.getNumElements().getType().isIndex())
|
||||||
return emitOpError("expected num elements to be of index type");
|
return op.emitOpError("expected num elements to be of index type");
|
||||||
|
|
||||||
// 4. Tag memref.
|
// 4. Tag memref.
|
||||||
if (!getTagMemRef().getType().isa<MemRefType>())
|
if (!op.getTagMemRef().getType().isa<MemRefType>())
|
||||||
return emitOpError("expected tag to be of memref type");
|
return op.emitOpError("expected tag to be of memref type");
|
||||||
numExpectedOperands += getTagMemRefRank();
|
numExpectedOperands += op.getTagMemRefRank();
|
||||||
if (numOperands < numExpectedOperands)
|
if (numOperands < numExpectedOperands)
|
||||||
return emitOpError() << "expected at least " << numExpectedOperands
|
return op.emitOpError()
|
||||||
<< " operands";
|
<< "expected at least " << numExpectedOperands << " operands";
|
||||||
if (!getTagIndices().empty() &&
|
if (!op.getTagIndices().empty() &&
|
||||||
!llvm::all_of(getTagIndices().getTypes(),
|
!llvm::all_of(op.getTagIndices().getTypes(),
|
||||||
[](Type t) { return t.isIndex(); }))
|
[](Type t) { return t.isIndex(); }))
|
||||||
return emitOpError("expected tag indices to be of index type");
|
return op.emitOpError("expected tag indices to be of index type");
|
||||||
|
|
||||||
// Optional stride-related operands must be either both present or both
|
// Optional stride-related operands must be either both present or both
|
||||||
// absent.
|
// absent.
|
||||||
if (numOperands != numExpectedOperands &&
|
if (numOperands != numExpectedOperands &&
|
||||||
numOperands != numExpectedOperands + 2)
|
numOperands != numExpectedOperands + 2)
|
||||||
return emitOpError("incorrect number of operands");
|
return op.emitOpError("incorrect number of operands");
|
||||||
|
|
||||||
// 5. Strides.
|
// 5. Strides.
|
||||||
if (isStrided()) {
|
if (op.isStrided()) {
|
||||||
if (!getStride().getType().isIndex() ||
|
if (!op.getStride().getType().isIndex() ||
|
||||||
!getNumElementsPerStride().getType().isIndex())
|
!op.getNumElementsPerStride().getType().isIndex())
|
||||||
return emitOpError(
|
return op.emitOpError(
|
||||||
"expected stride and num elements per stride to be of type index");
|
"expected stride and num elements per stride to be of type index");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1065,74 +1068,20 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
|
||||||
// DmaWaitOp
|
// DmaWaitOp
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
|
|
||||||
Value tagMemRef, ValueRange tagIndices,
|
|
||||||
Value numElements) {
|
|
||||||
result.addOperands(tagMemRef);
|
|
||||||
result.addOperands(tagIndices);
|
|
||||||
result.addOperands(numElements);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DmaWaitOp::print(OpAsmPrinter &p) {
|
|
||||||
p << " " << getTagMemRef() << '[' << getTagIndices() << "], "
|
|
||||||
<< getNumElements();
|
|
||||||
p.printOptionalAttrDict((*this)->getAttrs());
|
|
||||||
p << " : " << getTagMemRef().getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse DmaWaitOp.
|
|
||||||
// Eg:
|
|
||||||
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
|
|
||||||
//
|
|
||||||
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
OpAsmParser::OperandType tagMemrefInfo;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
|
|
||||||
Type type;
|
|
||||||
auto indexType = parser.getBuilder().getIndexType();
|
|
||||||
OpAsmParser::OperandType numElementsInfo;
|
|
||||||
|
|
||||||
// Parse tag memref, its indices, and dma size.
|
|
||||||
if (parser.parseOperand(tagMemrefInfo) ||
|
|
||||||
parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
|
|
||||||
parser.parseComma() || parser.parseOperand(numElementsInfo) ||
|
|
||||||
parser.parseColonType(type) ||
|
|
||||||
parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
|
|
||||||
parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
|
|
||||||
parser.resolveOperand(numElementsInfo, indexType, result.operands))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
/// dma_wait(memrefcast) -> dma_wait
|
/// dma_wait(memrefcast) -> dma_wait
|
||||||
return foldMemRefCast(*this);
|
return foldMemRefCast(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult DmaWaitOp::verify() {
|
static LogicalResult verify(DmaWaitOp op) {
|
||||||
// Mandatory non-variadic operands are tag and the number of elements.
|
// Check that the number of tag indices matches the tagMemRef rank.
|
||||||
if (getNumOperands() < 2)
|
unsigned numTagIndices = op.tagIndices().size();
|
||||||
return emitOpError() << "expected at least 2 operands";
|
unsigned tagMemRefRank = op.getTagMemRefRank();
|
||||||
|
if (numTagIndices != tagMemRefRank)
|
||||||
// Check types of operands. The order of these calls is important: the later
|
return op.emitOpError() << "expected tagIndices to have the same number of "
|
||||||
// calls rely on some type properties to compute the operand position.
|
"elements as the tagMemRef rank, expected "
|
||||||
if (!getTagMemRef().getType().isa<MemRefType>())
|
<< tagMemRefRank << ", but got " << numTagIndices;
|
||||||
return emitOpError() << "expected tag to be of memref type";
|
|
||||||
|
|
||||||
if (getNumOperands() != 2 + getTagMemRefRank())
|
|
||||||
return emitOpError() << "expected " << 2 + getTagMemRefRank()
|
|
||||||
<< " operands";
|
|
||||||
|
|
||||||
if (!getTagIndices().empty() &&
|
|
||||||
!llvm::all_of(getTagIndices().getTypes(),
|
|
||||||
[](Type t) { return t.isIndex(); }))
|
|
||||||
return emitOpError() << "expected tag indices to be of index type";
|
|
||||||
|
|
||||||
if (!getNumElements().getType().isIndex())
|
|
||||||
return emitOpError()
|
|
||||||
<< "expected the number of elements to be of index type";
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,132 @@
|
||||||
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
|
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
|
||||||
|
|
||||||
|
func @dma_start_not_enough_operands() {
|
||||||
|
// expected-error@+1 {{expected at least 4 operands}}
|
||||||
|
"memref.dma_start"() : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
|
||||||
|
// expected-error@+1 {{expected source to be of memref type}}
|
||||||
|
memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_not_enough_operands_for_src(
|
||||||
|
%src: memref<2x2x2xf32>, %idx: index) {
|
||||||
|
// expected-error@+1 {{expected at least 7 operands}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_src_index_wrong_type(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>, %flt: f32) {
|
||||||
|
// expected-error@+1 {{expected source indices to be of index type}}
|
||||||
|
"memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
|
||||||
|
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
|
||||||
|
%mref = memref.alloc() : memref<8 x f32>
|
||||||
|
// expected-error@+1 {{expected destination to be of memref type}}
|
||||||
|
memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_not_enough_operands_for_dst(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>) {
|
||||||
|
// expected-error@+1 {{expected at least 7 operands}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_dst_index_wrong_type(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>, %flt: f32) {
|
||||||
|
// expected-error@+1 {{expected destination indices to be of index type}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_dst_index_wrong_type(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>, %flt: f32) {
|
||||||
|
// expected-error@+1 {{expected num elements to be of index type}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
|
||||||
|
%mref = memref.alloc() : memref<8 x f32>
|
||||||
|
// expected-error@+1 {{expected tag to be of memref type}}
|
||||||
|
memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_not_enough_operands_for_tag(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<2xi32,2>) {
|
||||||
|
// expected-error@+1 {{expected at least 8 operands}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_dst_index_wrong_type(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<2xi32,2>, %flt: f32) {
|
||||||
|
// expected-error@+1 {{expected tag indices to be of index type}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_too_many_operands(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>) {
|
||||||
|
// expected-error@+1 {{incorrect number of operands}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_start_wrong_stride_type(
|
||||||
|
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||||
|
%tag: memref<i32,2>, %flt: f32) {
|
||||||
|
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
|
||||||
|
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
|
||||||
|
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt: index) {
|
||||||
|
// expected-error@+1 {{expected tagIndices to have the same number of elements as the tagMemRef rank, expected 2, but got 1}}
|
||||||
|
"memref.dma_wait"(%tag, %flt, %idx) : (memref<2x2xi32>, index, index) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
|
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
|
||||||
// expected-error @+1 {{expected a permutation map}}
|
// expected-error @+1 {{expected a permutation map}}
|
||||||
memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
|
memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
|
||||||
|
|
|
@ -290,153 +290,6 @@ func @invalid_cmp_shape(%idx : () -> ()) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dma_start_not_enough_operands() {
|
|
||||||
// expected-error@+1 {{expected at least 4 operands}}
|
|
||||||
"memref.dma_start"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
|
|
||||||
// expected-error@+1 {{expected source to be of memref type}}
|
|
||||||
memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_not_enough_operands_for_src(
|
|
||||||
%src: memref<2x2x2xf32>, %idx: index) {
|
|
||||||
// expected-error@+1 {{expected at least 7 operands}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_src_index_wrong_type(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected source indices to be of index type}}
|
|
||||||
"memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
|
|
||||||
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
|
|
||||||
%mref = memref.alloc() : memref<8 x f32>
|
|
||||||
// expected-error@+1 {{expected destination to be of memref type}}
|
|
||||||
memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_not_enough_operands_for_dst(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>) {
|
|
||||||
// expected-error@+1 {{expected at least 7 operands}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_dst_index_wrong_type(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected destination indices to be of index type}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_dst_index_wrong_type(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected num elements to be of index type}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
|
|
||||||
%mref = memref.alloc() : memref<8 x f32>
|
|
||||||
// expected-error@+1 {{expected tag to be of memref type}}
|
|
||||||
memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_not_enough_operands_for_tag(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<2xi32,2>) {
|
|
||||||
// expected-error@+1 {{expected at least 8 operands}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_dst_index_wrong_type(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<2xi32,2>, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected tag indices to be of index type}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_too_many_operands(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>) {
|
|
||||||
// expected-error@+1 {{incorrect number of operands}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_start_wrong_stride_type(
|
|
||||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
|
||||||
%tag: memref<i32,2>, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
|
|
||||||
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
|
|
||||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_wait_not_enough_operands() {
|
|
||||||
// expected-error@+1 {{expected at least 2 operands}}
|
|
||||||
"memref.dma_wait"() : () -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
|
|
||||||
// expected-error@+1 {{expected tag to be of memref type}}
|
|
||||||
"memref.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected tag indices to be of index type}}
|
|
||||||
"memref.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
|
|
||||||
// expected-error@+1 {{expected the number of elements to be of index type}}
|
|
||||||
"memref.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @invalid_cmp_attr(%idx : i32) {
|
func @invalid_cmp_attr(%idx : i32) {
|
||||||
// expected-error@+1 {{expected string or keyword containing one of the following enum values}}
|
// expected-error@+1 {{expected string or keyword containing one of the following enum values}}
|
||||||
%cmp = cmpi i1, %idx, %idx : i32
|
%cmp = cmpi i1, %idx, %idx : i32
|
||||||
|
|
Loading…
Reference in New Issue