[mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS

These are among the last operations still defined explicitly in C++. I've
tried to keep this commit as NFC as possible, but these ops
definitely need a non-NFC cleanup at some point.

Differential Revision: https://reviews.llvm.org/D110440
This commit is contained in:
River Riddle 2021-09-24 19:32:23 +00:00
parent 96cb97c453
commit aca9bea199
6 changed files with 395 additions and 459 deletions

View File

@ -46,206 +46,4 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc" #include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
namespace mlir {
namespace memref {
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memref need
// not be of the same dimensionality, but need to have the same elemental type.
// The operands include the source and destination memref's each followed by its
// indices, size of the data transfer in terms of the number of elements (of the
// elemental type of the memref), a tag memref with its indices, and optionally
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
// restrictions as any load/store. The optional stride arguments should be of
// 'index' type, and specify a stride for the slower memory space (memory space
// with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or no stride arguments should be specified. If the
// source and destination locations overlap the behavior of this operation is
// not defined.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// 1 at indices [%k, %l], would be specified as follows:
//
// %num_elements = constant 256
// %idx = constant 0 : index
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
// memref<1 x i32>, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
// %num_elt_per_stride :
//
// TODO: add additional operands to allow source and destination striding, and
// multiple stride levels.
// TODO: Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
ValueRange srcIndices, Value destMemRef,
ValueRange destIndices, Value numElements, Value tagMemRef,
ValueRange tagIndices, Value stride = nullptr,
Value elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return getSrcMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the source memref indices for this DMA operation.
operand_range getSrcIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
}
// Returns the destination MemRefType for this DMA operations.
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef().getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() {
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
unsigned getDstMemorySpace() {
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
// Returns the destination memref indices for this DMA operation.
operand_range getDstIndices() {
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
getDstMemRefRank()};
}
// Returns the number of elements being transferred by this DMA operation.
Value getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
unsigned tagIndexStartPos =
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
return {(*this)->operand_begin() + tagIndexStartPos,
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a slower memory space to a faster one.
bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
}
static StringRef getOperationName() { return "memref.dma_start"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
1 + 1 + getTagMemRefRank();
}
Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
};
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
// memref<2048 x f32>, (d0) -> (d0), 0>,
// memref<256 x f32>, (d0) -> (d0), 1>
// memref<1 x i32>, (d0) -> (d0), 2>
// ...
// ...
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
ValueRange tagIndices, Value numElements);
static StringRef getOperationName() { return "memref.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
Value getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getTagMemRefRank()};
}
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the number of elements transferred in the associated DMA operation.
Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
LogicalResult verify();
};
} // namespace memref
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_ #endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_

View File

@ -284,8 +284,6 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
let verifier = ?; let verifier = ?;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BufferCastOp // BufferCastOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -568,6 +566,217 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
let hasFolder = 1; let hasFolder = 1;
} }
//===----------------------------------------------------------------------===//
// DmaStartOp
//===----------------------------------------------------------------------===//
def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
let summary = "non-blocking DMA operation that starts a transfer";
let description = [{
DmaStartOp starts a non-blocking DMA operation that transfers data from a
source memref to a destination memref. The source and destination memref
need not be of the same dimensionality, but need to have the same elemental
type. The operands include the source and destination memref's each followed
by its indices, size of the data transfer in terms of the number of elements
(of the elemental type of the memref), a tag memref with its indices, and
optionally at the end, a stride and a number_of_elements_per_stride
arguments. The tag location is used by a DmaWaitOp to check for completion.
The indices of the source memref, destination memref, and the tag memref
have the same restrictions as any load/store. The optional stride arguments
should be of 'index' type, and specify a stride for the slower memory space
(memory space with a lower memory space id), transferring chunks of
number_of_elements_per_stride every stride until %num_elements are
transferred. Either both or no stride arguments should be specified. If the
source and destination locations overlap the behavior of this operation is
not defined.
For example, a DmaStartOp operation that transfers 256 elements of a memref
'%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
space 1 at indices [%k, %l], would be specified as follows:
```mlir
%num_elements = constant 256
%idx = constant 0 : index
%tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
memref<40 x 128 x f32>, (d0) -> (d0), 0>,
memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
memref<1 x i32>, (d0) -> (d0), 2>
```
If %stride and %num_elt_per_stride are specified, the DMA is expected to
transfer %num_elt_per_stride elements every %stride elements apart from
memory space 0 until %num_elements are transferred.
```mlir
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
%num_elt_per_stride :
```
TODO: add additional operands to allow source and destination striding, and
multiple stride levels.
TODO: Consider replacing src/dst memref indices with view memrefs.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [
OpBuilder<(ins "Value":$srcMemRef, "ValueRange":$srcIndices,
"Value":$destMemRef, "ValueRange":$destIndices,
"Value":$numElements, "Value":$tagMemRef,
"ValueRange":$tagIndices, CArg<"Value", "{}">:$stride,
CArg<"Value", "{}">:$elementsPerStride)>
];
let extraClassDeclaration = [{
// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return getSrcMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the source memref indices for this DMA operation.
operand_range getSrcIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
}
// Returns the destination MemRefType for this DMA operations.
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef().getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() {
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
unsigned getDstMemorySpace() {
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
// Returns the destination memref indices for this DMA operation.
operand_range getDstIndices() {
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
getDstMemRefRank()};
}
// Returns the number of elements being transferred by this DMA operation.
Value getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
unsigned tagIndexStartPos =
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
return {(*this)->operand_begin() + tagIndexStartPos,
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
}
/// Returns true if this is a DMA from a faster memory space to a slower
/// one.
bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a slower memory space to a faster
/// one.
bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
}
bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 +
getDstMemRefRank() + 1 + 1 +
getTagMemRefRank();
}
Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// DmaWaitOp
//===----------------------------------------------------------------------===//
def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
let summary = "blocking DMA operation that waits for transfer completion";
let description = [{
DmaWaitOp blocks until the completion of a DMA operation associated with the
tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
with the same restrictions as any load/store index. %num_elements is the
number of elements associated with the DMA operation.
Example:
```mlir
dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
memref<2048 x f32>, (d0) -> (d0), 0>,
memref<256 x f32>, (d0) -> (d0), 1>
memref<1 x i32>, (d0) -> (d0), 2>
...
...
dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
```
}];
let arguments = (ins
AnyMemRef:$tagMemRef,
Variadic<Index>:$tagIndices,
Index:$numElements
);
let assemblyFormat = [{
$tagMemRef `[` $tagIndices `]` `,` $numElements attr-dict `:`
type($tagMemRef)
}];
let extraClassDeclaration = [{
/// Returns the Tag MemRef associated with the DMA operation being waited
/// on.
Value getTagMemRef() { return tagMemRef(); }
/// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() { return tagIndices(); }
/// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
/// Returns the number of elements transferred in the associated DMA
/// operation.
Value getNumElements() { return numElements(); }
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GetGlobalOp // GetGlobalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -34,9 +34,9 @@ struct MemRefInlinerInterface : public DialectInlinerInterface {
} // end anonymous namespace } // end anonymous namespace
void mlir::memref::MemRefDialect::initialize() { void mlir::memref::MemRefDialect::initialize() {
addOperations<DmaStartOp, DmaWaitOp, addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
>(); >();
addInterfaces<MemRefInlinerInterface>(); addInterfaces<MemRefInlinerInterface>();
} }

View File

@ -909,16 +909,17 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result,
result.addOperands({stride, elementsPerStride}); result.addOperands({stride, elementsPerStride});
} }
void DmaStartOp::print(OpAsmPrinter &p) { static void print(OpAsmPrinter &p, DmaStartOp op) {
p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
<< getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
<< ", " << getTagMemRef() << '[' << getTagIndices() << ']'; << op.getNumElements() << ", " << op.getTagMemRef() << '['
if (isStrided()) << op.getTagIndices() << ']';
p << ", " << getStride() << ", " << getNumElementsPerStride(); if (op.isStrided())
p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
p.printOptionalAttrDict((*this)->getAttrs()); p.printOptionalAttrDict(op->getAttrs());
p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() p << " : " << op.getSrcMemRef().getType() << ", "
<< ", " << getTagMemRef().getType(); << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
} }
// Parse DmaStartOp. // Parse DmaStartOp.
@ -929,7 +930,8 @@ void DmaStartOp::print(OpAsmPrinter &p) {
// memref<1024 x f32, 2>, // memref<1024 x f32, 2>,
// memref<1 x i32> // memref<1 x i32>
// //
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { static ParseResult parseDmaStartOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcMemRefInfo; OpAsmParser::OperandType srcMemRefInfo;
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
OpAsmParser::OperandType dstMemRefInfo; OpAsmParser::OperandType dstMemRefInfo;
@ -989,66 +991,67 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
return success(); return success();
} }
LogicalResult DmaStartOp::verify() { static LogicalResult verify(DmaStartOp op) {
unsigned numOperands = getNumOperands(); unsigned numOperands = op.getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements. // the number of elements.
if (numOperands < 4) if (numOperands < 4)
return emitOpError("expected at least 4 operands"); return op.emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later // Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position. // calls rely on some type properties to compute the operand position.
// 1. Source memref. // 1. Source memref.
if (!getSrcMemRef().getType().isa<MemRefType>()) if (!op.getSrcMemRef().getType().isa<MemRefType>())
return emitOpError("expected source to be of memref type"); return op.emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4) if (numOperands < op.getSrcMemRefRank() + 4)
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 return op.emitOpError()
<< " operands"; << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
if (!getSrcIndices().empty() && if (!op.getSrcIndices().empty() &&
!llvm::all_of(getSrcIndices().getTypes(), !llvm::all_of(op.getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return emitOpError("expected source indices to be of index type"); return op.emitOpError("expected source indices to be of index type");
// 2. Destination memref. // 2. Destination memref.
if (!getDstMemRef().getType().isa<MemRefType>()) if (!op.getDstMemRef().getType().isa<MemRefType>())
return emitOpError("expected destination to be of memref type"); return op.emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; unsigned numExpectedOperands =
op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands) if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands return op.emitOpError()
<< " operands"; << "expected at least " << numExpectedOperands << " operands";
if (!getDstIndices().empty() && if (!op.getDstIndices().empty() &&
!llvm::all_of(getDstIndices().getTypes(), !llvm::all_of(op.getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return emitOpError("expected destination indices to be of index type"); return op.emitOpError("expected destination indices to be of index type");
// 3. Number of elements. // 3. Number of elements.
if (!getNumElements().getType().isIndex()) if (!op.getNumElements().getType().isIndex())
return emitOpError("expected num elements to be of index type"); return op.emitOpError("expected num elements to be of index type");
// 4. Tag memref. // 4. Tag memref.
if (!getTagMemRef().getType().isa<MemRefType>()) if (!op.getTagMemRef().getType().isa<MemRefType>())
return emitOpError("expected tag to be of memref type"); return op.emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank(); numExpectedOperands += op.getTagMemRefRank();
if (numOperands < numExpectedOperands) if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands return op.emitOpError()
<< " operands"; << "expected at least " << numExpectedOperands << " operands";
if (!getTagIndices().empty() && if (!op.getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(), !llvm::all_of(op.getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return emitOpError("expected tag indices to be of index type"); return op.emitOpError("expected tag indices to be of index type");
// Optional stride-related operands must be either both present or both // Optional stride-related operands must be either both present or both
// absent. // absent.
if (numOperands != numExpectedOperands && if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2) numOperands != numExpectedOperands + 2)
return emitOpError("incorrect number of operands"); return op.emitOpError("incorrect number of operands");
// 5. Strides. // 5. Strides.
if (isStrided()) { if (op.isStrided()) {
if (!getStride().getType().isIndex() || if (!op.getStride().getType().isIndex() ||
!getNumElementsPerStride().getType().isIndex()) !op.getNumElementsPerStride().getType().isIndex())
return emitOpError( return op.emitOpError(
"expected stride and num elements per stride to be of type index"); "expected stride and num elements per stride to be of type index");
} }
@ -1065,74 +1068,20 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
// DmaWaitOp // DmaWaitOp
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
Value tagMemRef, ValueRange tagIndices,
Value numElements) {
result.addOperands(tagMemRef);
result.addOperands(tagIndices);
result.addOperands(numElements);
}
void DmaWaitOp::print(OpAsmPrinter &p) {
p << " " << getTagMemRef() << '[' << getTagIndices() << "], "
<< getNumElements();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getTagMemRef().getType();
}
// Parse DmaWaitOp.
// Eg:
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type type;
auto indexType = parser.getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its indices, and dma size.
if (parser.parseOperand(tagMemrefInfo) ||
parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
parser.parseComma() || parser.parseOperand(numElementsInfo) ||
parser.parseColonType(type) ||
parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();
return success();
}
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) { SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait /// dma_wait(memrefcast) -> dma_wait
return foldMemRefCast(*this); return foldMemRefCast(*this);
} }
LogicalResult DmaWaitOp::verify() { static LogicalResult verify(DmaWaitOp op) {
// Mandatory non-variadic operands are tag and the number of elements. // Check that the number of tag indices matches the tagMemRef rank.
if (getNumOperands() < 2) unsigned numTagIndices = op.tagIndices().size();
return emitOpError() << "expected at least 2 operands"; unsigned tagMemRefRank = op.getTagMemRefRank();
if (numTagIndices != tagMemRefRank)
// Check types of operands. The order of these calls is important: the later return op.emitOpError() << "expected tagIndices to have the same number of "
// calls rely on some type properties to compute the operand position. "elements as the tagMemRef rank, expected "
if (!getTagMemRef().getType().isa<MemRefType>()) << tagMemRefRank << ", but got " << numTagIndices;
return emitOpError() << "expected tag to be of memref type";
if (getNumOperands() != 2 + getTagMemRefRank())
return emitOpError() << "expected " << 2 + getTagMemRefRank()
<< " operands";
if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError() << "expected tag indices to be of index type";
if (!getNumElements().getType().isIndex())
return emitOpError()
<< "expected the number of elements to be of index type";
return success(); return success();
} }

View File

@ -1,5 +1,132 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics // RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @dma_start_not_enough_operands() {
// expected-error@+1 {{expected at least 4 operands}}
"memref.dma_start"() : () -> ()
}
// -----
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
// expected-error@+1 {{expected source to be of memref type}}
memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
}
// -----
func @dma_start_not_enough_operands_for_src(
%src: memref<2x2x2xf32>, %idx: index) {
// expected-error@+1 {{expected at least 7 operands}}
"memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
}
// -----
func @dma_start_src_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected source indices to be of index type}}
"memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
}
// -----
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
%mref = memref.alloc() : memref<8 x f32>
// expected-error@+1 {{expected destination to be of memref type}}
memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
}
// -----
func @dma_start_not_enough_operands_for_dst(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{expected at least 7 operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected destination indices to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected num elements to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
}
// -----
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
%mref = memref.alloc() : memref<8 x f32>
// expected-error@+1 {{expected tag to be of memref type}}
memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
}
// -----
func @dma_start_not_enough_operands_for_tag(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>) {
// expected-error@+1 {{expected at least 8 operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
}
// -----
func @dma_start_too_many_operands(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{incorrect number of operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
}
// -----
func @dma_start_wrong_stride_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
}
// -----
func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt: index) {
// expected-error@+1 {{expected tagIndices to have the same number of elements as the tagMemRef rank, expected 2, but got 1}}
"memref.dma_wait"(%tag, %flt, %idx) : (memref<2x2xi32>, index, index) -> ()
return
}
// -----
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) { func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}} // expected-error @+1 {{expected a permutation map}}
memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>

View File

@ -290,153 +290,6 @@ func @invalid_cmp_shape(%idx : () -> ()) {
// ----- // -----
func @dma_start_not_enough_operands() {
// expected-error@+1 {{expected at least 4 operands}}
"memref.dma_start"() : () -> ()
}
// -----
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
// expected-error@+1 {{expected source to be of memref type}}
memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
}
// -----
func @dma_start_not_enough_operands_for_src(
%src: memref<2x2x2xf32>, %idx: index) {
// expected-error@+1 {{expected at least 7 operands}}
"memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
}
// -----
func @dma_start_src_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected source indices to be of index type}}
"memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
}
// -----
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
%mref = memref.alloc() : memref<8 x f32>
// expected-error@+1 {{expected destination to be of memref type}}
memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
}
// -----
func @dma_start_not_enough_operands_for_dst(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{expected at least 7 operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected destination indices to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected num elements to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
}
// -----
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
%mref = memref.alloc() : memref<8 x f32>
// expected-error@+1 {{expected tag to be of memref type}}
memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
}
// -----
func @dma_start_not_enough_operands_for_tag(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>) {
// expected-error@+1 {{expected at least 8 operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
}
// -----
func @dma_start_too_many_operands(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{incorrect number of operands}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
}
// -----
func @dma_start_wrong_stride_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
"memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
}
// -----
func @dma_wait_not_enough_operands() {
// expected-error@+1 {{expected at least 2 operands}}
"memref.dma_wait"() : () -> ()
}
// -----
func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
// expected-error@+1 {{expected tag to be of memref type}}
"memref.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
}
// -----
func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"memref.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
}
// -----
func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected the number of elements to be of index type}}
"memref.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
}
// -----
func @invalid_cmp_attr(%idx : i32) { func @invalid_cmp_attr(%idx : i32) {
// expected-error@+1 {{expected string or keyword containing one of the following enum values}} // expected-error@+1 {{expected string or keyword containing one of the following enum values}}
%cmp = cmpi i1, %idx, %idx : i32 %cmp = cmpi i1, %idx, %idx : i32