Add new operations affine.dma_start and affine.dma_wait which take affine maps for indexing memrefs by construction.

These ops are analogues of the current standard ops dma_start/dma_wait, with the exception that the memref indices are affine expressions of loop IVs and symbols (analogous to affine.load/store).
The addition of these operations will enable changes to affine transformation and analysis passes which operate on memory dereferencing operations.

PiperOrigin-RevId: 255658382
This commit is contained in:
Andy Davis 2019-06-28 13:31:31 -07:00 committed by A. Unique TensorFlower
parent 7b17f4e647
commit 6c68596aee
3 changed files with 644 additions and 6 deletions

View File

@ -90,6 +90,285 @@ public:
MLIRContext *context);
};
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
/// from a source memref to a destination memref. The source and destination
/// memref need not be of the same dimensionality, but need to have the same
/// elemental type. The operands include the source and destination memref's
/// each followed by its indices, size of the data transfer in terms of the
/// number of elements (of the elemental type of the memref), a tag memref with
/// its indices, and optionally at the end, a stride and a
/// number_of_elements_per_stride arguments. The tag location is used by an
/// AffineDmaWaitOp to check for completion. The indices of the source memref,
/// destination memref, and the tag memref have the same restrictions as any
/// affine.load/store. In particular, index for each memref dimension must be an
/// affine expression of loop induction variables and symbols.
/// The optional stride arguments should be of 'index' type, and specify a
/// stride for the slower memory space (memory space with a lower memory space
/// id), transferring chunks of number_of_elements_per_stride every stride until
/// %num_elements are transferred. Either both or no stride arguments should be
/// specified. The value of 'num_elements' must be a multiple of
/// 'number_of_elements_per_stride'.
//
// For example, an AffineDmaStartOp operation that transfers 256 elements of a
// memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in
// memory space 1 at indices [%k + 7, %l], would be specified as follows:
//
// %num_elements = constant 256 : index
// %idx = constant 0 : index
// %tag = alloc() : memref<1xi32, 2>
// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
// %num_elements :
// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
// %stride, %num_elt_per_stride : ...
//
// TODO(mlir-team): add additional operands to allow source and destination
// striding, and multiple stride levels (possibly using AffineMaps to specify
// multiple levels of striding).
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
                                   OpTrait::ZeroResult> {
public:
  using Op::Op;

  // Operand layout, as implied by the index accessors below:
  //   [src memref, src map operands..., dst memref, dst map operands...,
  //    tag memref, tag map operands..., num elements,
  //    (stride, num elements per stride)?]
  // The number of map operands that follow each memref equals the number of
  // inputs of the corresponding affine map attribute ("src_map", "dst_map",
  // "tag_map").

  /// Populates 'result' with the operands and map attributes of an
  /// affine.dma_start. 'stride' and 'elementsPerStride' are optional and
  /// default to null (non-strided DMA).
  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
                    AffineMap srcMap, ArrayRef<Value *> srcIndices,
                    Value *destMemRef, AffineMap dstMap,
                    ArrayRef<Value *> destIndices, Value *tagMemRef,
                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
                    Value *numElements, Value *stride = nullptr,
                    Value *elementsPerStride = nullptr);

  //===--------------------------------------------------------------------===//
  // Source memref
  //===--------------------------------------------------------------------===//

  /// Returns the operand index of the src memref.
  unsigned getSrcMemRefOperandIndex() { return 0; }

  /// Returns the source memref operand of this DMA operation.
  Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }

  /// Returns the source MemRefType for this DMA operation.
  MemRefType getSrcMemRefType() {
    return getSrcMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the source MemRefType.
  unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }

  /// Returns the affine map used to access the src memref.
  AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }

  /// Returns the attribute holding the affine map of the src memref.
  AffineMapAttr getSrcMapAttr() {
    return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the source memref affine map indices for this DMA operation.
  operand_range getSrcIndices() {
    auto first = operand_begin() + getSrcMemRefOperandIndex() + 1;
    return {first, first + getSrcMap().getNumInputs()};
  }

  /// Returns the memory space of the src memref.
  unsigned getSrcMemorySpace() { return getSrcMemRefType().getMemorySpace(); }

  //===--------------------------------------------------------------------===//
  // Destination memref
  //===--------------------------------------------------------------------===//

  /// Returns the operand index of the dst memref; it directly follows the
  /// src memref and its map operands.
  unsigned getDstMemRefOperandIndex() {
    return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
  }

  /// Returns the destination memref operand of this DMA operation.
  Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }

  /// Returns the destination MemRefType for this DMA operation.
  MemRefType getDstMemRefType() {
    return getDstMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the destination MemRefType.
  unsigned getDstMemRefRank() { return getDstMemRefType().getRank(); }

  /// Returns the memory space of the dst memref.
  unsigned getDstMemorySpace() { return getDstMemRefType().getMemorySpace(); }

  /// Returns the affine map used to access the dst memref.
  AffineMap getDstMap() { return getDstMapAttr().getValue(); }

  /// Returns the attribute holding the affine map of the dst memref.
  AffineMapAttr getDstMapAttr() {
    return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the destination memref indices for this DMA operation.
  operand_range getDstIndices() {
    auto first = operand_begin() + getDstMemRefOperandIndex() + 1;
    return {first, first + getDstMap().getNumInputs()};
  }

  //===--------------------------------------------------------------------===//
  // Tag memref
  //===--------------------------------------------------------------------===//

  /// Returns the operand index of the tag memref; it directly follows the
  /// dst memref and its map operands.
  unsigned getTagMemRefOperandIndex() {
    return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
  }

  /// Returns the Tag MemRef for this DMA operation.
  Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }

  /// Returns the tag MemRefType for this DMA operation.
  MemRefType getTagMemRefType() {
    return getTagMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the tag MemRefType.
  unsigned getTagMemRefRank() { return getTagMemRefType().getRank(); }

  /// Returns the affine map used to access the tag memref.
  AffineMap getTagMap() { return getTagMapAttr().getValue(); }

  /// Returns the attribute holding the affine map of the tag memref.
  AffineMapAttr getTagMapAttr() {
    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    auto first = operand_begin() + getTagMemRefOperandIndex() + 1;
    return {first, first + getTagMap().getNumInputs()};
  }

  //===--------------------------------------------------------------------===//
  // Transfer size, map lookup and memory hierarchy queries
  //===--------------------------------------------------------------------===//

  /// Returns the number of elements being transferred by this DMA operation.
  /// This is the operand that directly follows the tag map operands.
  Value *getNumElements() {
    return getOperand(getTagMemRefOperandIndex() + 1 +
                      getTagMap().getNumInputs());
  }

  /// Returns the (attribute name, AffineMapAttr) pair associated with
  /// 'memref', which must be the source, destination or tag memref of this
  /// operation. The source is checked first, then the destination.
  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
    if (memref == getSrcMemRef())
      return {Identifier::get(getSrcMapAttrName(), getContext()),
              getSrcMapAttr()};
    if (memref == getDstMemRef())
      return {Identifier::get(getDstMapAttrName(), getContext()),
              getDstMapAttr()};
    assert(memref == getTagMemRef() &&
           "DmaStartOp expected source, destination or tag memref");
    return {Identifier::get(getTagMapAttrName(), getContext()),
            getTagMapAttr()};
  }

  /// Returns true if the destination memory space is faster than the source
  /// one, i.e. this is a DMA from a slower memory space to a faster one.
  /// Assumes that a lower number is for a slower memory space.
  bool isDestMemorySpaceFaster() {
    return (getSrcMemorySpace() < getDstMemorySpace());
  }

  /// Returns true if the source memory space is faster than the destination
  /// one, i.e. this is a DMA from a faster memory space to a slower one.
  /// Assumes that a lower number is for a slower memory space.
  bool isSrcMemorySpaceFaster() {
    return (getDstMemorySpace() < getSrcMemorySpace());
  }

  /// Given a DMA start operation, returns the operand position of either the
  /// source or destination memref depending on the one that is at the higher
  /// level of the memory hierarchy. Asserts failure if neither is true (i.e.
  /// both memrefs live in the same memory space).
  unsigned getFasterMemPos() {
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
    return isSrcMemorySpaceFaster() ? getSrcMemRefOperandIndex()
                                    : getDstMemRefOperandIndex();
  }

  /// Names of the attributes that hold the src/dst/tag affine maps.
  static StringRef getSrcMapAttrName() { return "src_map"; }
  static StringRef getDstMapAttrName() { return "dst_map"; }
  static StringRef getTagMapAttrName() { return "tag_map"; }

  static StringRef getOperationName() { return "affine.dma_start"; }

  // Hooks for the custom assembly form and verification.
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();

  /// Returns true if this DMA operation is strided, returns false otherwise.
  /// The operation counts as strided whenever there are operands beyond the
  /// 'num elements' operand.
  bool isStrided() {
    return getNumOperands() !=
           getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
  }

  /// Returns the stride value for this DMA operation (the second-to-last
  /// operand), or null if the operation is not strided.
  Value *getStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1 - 1);
  }

  /// Returns the number of elements to transfer per stride for this DMA op
  /// (the last operand), or null if the operation is not strided.
  Value *getNumElementsPerStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1);
  }
};
/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
/// an index with the same restrictions as any load/store index. In particular,
/// index for each memref dimension must be an affine expression of loop
/// induction variables and symbols. %num_elements is the number of elements
/// associated with the DMA operation. For example:
//
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
// ...
// ...
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
//
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
                                  OpTrait::ZeroResult> {
public:
  using Op::Op;

  // Operand layout, as implied by the accessors below:
  //   [tag memref, tag map operands..., num elements]

  /// Populates 'result' with the operands and the tag map attribute of an
  /// affine.dma_wait.
  static void build(Builder *builder, OperationState *result, Value *tagMemRef,
                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
                    Value *numElements);

  static StringRef getOperationName() { return "affine.dma_wait"; }

  /// Name of the attribute that holds the tag affine map.
  static StringRef getTagMapAttrName() { return "tag_map"; }

  /// Returns the Tag MemRef associated with the DMA operation being waited
  /// on; it is always the first operand.
  Value *getTagMemRef() { return getOperand(0); }

  /// Returns the type of the tag memref.
  MemRefType getTagMemRefType() {
    return getTagMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() { return getTagMemRefType().getRank(); }

  /// Returns the attribute holding the affine map of the tag memref.
  AffineMapAttr getTagMapAttr() {
    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the affine map used to access the tag memref.
  AffineMap getTagMap() { return getTagMapAttr().getValue(); }

  /// Returns the tag memref map operands for this DMA operation, i.e. the
  /// operands that directly follow the tag memref.
  operand_range getTagIndices() {
    auto firstIndex = operand_begin() + 1;
    return {firstIndex, firstIndex + getTagMap().getNumInputs()};
  }

  /// Returns the (attribute name, AffineMapAttr) pair associated with
  /// 'memref', which must be the tag memref of this operation.
  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
    assert(memref == getTagMemRef());
    auto tagMapId = Identifier::get(getTagMapAttrName(), getContext());
    return {tagMapId, getTagMapAttr()};
  }

  /// Returns the number of elements transferred in the associated DMA op; it
  /// is the operand that follows the tag map operands.
  Value *getNumElements() {
    unsigned numElementsPos = 1 + getTagMap().getNumInputs();
    return getOperand(numElementsPos);
  }

  // Hooks for the custom assembly form and verification.
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
/// The "affine.for" operation represents an affine loop nest, defining an SSA
/// value for its induction variable. It has one region capturing the loop body.
/// The induction variable is represented as a argument of this region. This SSA
@ -382,10 +661,19 @@ public:
// Yields every operand except the first one.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() {
return getAttrOfType<AffineMapAttr>("map").getValue();
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the (attribute name, AffineMapAttr) pair for 'memref', which must
/// be the memref accessed by this operation.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
  assert(memref == getMemRef());
  auto mapAttrId = Identifier::get(getMapAttrName(), getContext());
  return {mapAttrId, getAffineMapAttr()};
}
/// Name of the attribute that holds the affine map of this operation.
static StringRef getMapAttrName() { return "map"; }
/// Name under which this operation is registered.
static StringRef getOperationName() { return "affine.load"; }
// Hooks to customize behavior of this op.
@ -435,10 +723,19 @@ public:
// Yields every operand except the first two.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() {
return getAttrOfType<AffineMapAttr>("map").getValue();
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
AffineMapAttr getAffineMapAttr() {
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
}
/// Returns the (attribute name, AffineMapAttr) pair for 'memref', which must
/// be the memref accessed by this operation.
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
  assert(memref == getMemRef());
  auto mapAttrId = Identifier::get(getMapAttrName(), getContext());
  return {mapAttrId, getAffineMapAttr()};
}
/// Name of the attribute that holds the affine map of this operation.
static StringRef getMapAttrName() { return "map"; }
/// Name under which this operation is registered.
static StringRef getOperationName() { return "affine.store"; }
// Hooks to customize behavior of this op.

View File

@ -37,8 +37,8 @@ using llvm::dbgs;
// Registers every operation of the affine dialect, including the DMA start
// and wait operations, exactly once.
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineForOp,
                AffineIfOp, AffineLoadOp, AffineStoreOp, AffineTerminatorOp>();
}
/// A utility function to check if a value is defined at the top level of a
@ -696,6 +696,220 @@ void AffineApplyOp::getCanonicalizationPatterns(
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
}
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
// Appends operands in the order the accessors expect:
//   src memref, src indices, dst memref, dst indices, tag memref, tag indices,
//   num elements, and optionally stride + elements per stride.
// The three affine maps are attached as the src/dst/tag map attributes.
// 'stride' and 'elementsPerStride' must be either both null or both non-null.
void AffineDmaStartOp::build(Builder *builder, OperationState *result,
                             Value *srcMemRef, AffineMap srcMap,
                             ArrayRef<Value *> srcIndices, Value *destMemRef,
                             AffineMap dstMap, ArrayRef<Value *> destIndices,
                             Value *tagMemRef, AffineMap tagMap,
                             ArrayRef<Value *> tagIndices, Value *numElements,
                             Value *stride, Value *elementsPerStride) {
  // A lone stride would append a null operand, and a lone elementsPerStride
  // would be silently dropped; reject both up front.
  assert((stride == nullptr) == (elementsPerStride == nullptr) &&
         "expected either both or no stride arguments");

  result->addOperands(srcMemRef);
  result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
  result->addOperands(srcIndices);

  result->addOperands(destMemRef);
  result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
  result->addOperands(destIndices);

  result->addOperands(tagMemRef);
  result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
  result->addOperands(tagIndices);

  result->addOperands(numElements);
  if (stride)
    result->addOperands({stride, elementsPerStride});
}
// Prints the custom form:
//   affine.dma_start %src[...], %dst[...], %tag[...], %num_elements
//       (, %stride, %num_elt_per_stride)? : <src type>, <dst type>, <tag type>
void AffineDmaStartOp::print(OpAsmPrinter *p) {
  // Emits `%memref[<affine map applied to its operands>]`.
  auto printAccess = [&](Value *memref, AffineMapAttr mapAttr,
                         operand_range indices) {
    *p << *memref << '[';
    SmallVector<Value *, 8> mapOperands(indices.begin(), indices.end());
    p->printAffineMapOfSSAIds(mapAttr, mapOperands);
    *p << ']';
  };

  *p << "affine.dma_start ";
  printAccess(getSrcMemRef(), getSrcMapAttr(), getSrcIndices());
  *p << ", ";
  printAccess(getDstMemRef(), getDstMapAttr(), getDstIndices());
  *p << ", ";
  printAccess(getTagMemRef(), getTagMapAttr(), getTagIndices());
  *p << ", " << *getNumElements();

  // The two stride operands are only present on strided DMAs.
  if (isStrided())
    *p << ", " << *getStride() << ", " << *getNumElementsPerStride();

  *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
     << getTagMemRefType();
}
// Parse AffineDmaStartOp.
// Ex:
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
// %stride, %num_elt_per_stride
// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
                                    OperationState *result) {
  // Parse state for each of the three memrefs: the memref operand, the affine
  // map built from its bracketed index list, and that map's SSA operands.
  OpAsmParser::OperandType srcMemRefInfo, dstMemRefInfo, tagMemRefInfo;
  AffineMapAttr srcMapAttr, dstMapAttr, tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
  SmallVector<Type, 3> types;
  auto indexType = parser->getBuilder().getIndexType();

  // Parse, in this order:
  //   *) src memref followed by its affine map operands (in square brackets).
  //   *) dst memref followed by its affine map operands (in square brackets).
  //   *) tag memref followed by its affine map operands (in square brackets).
  //   *) number of elements transferred by DMA operation.
  // Each parsed map is also recorded in result->attributes under its name.
  if (parser->parseOperand(srcMemRefInfo) || parser->parseLSquare() ||
      parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                     getSrcMapAttrName(), result->attributes) ||
      parser->parseRSquare() || parser->parseComma())
    return failure();
  if (parser->parseOperand(dstMemRefInfo) || parser->parseLSquare() ||
      parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                     getDstMapAttrName(), result->attributes) ||
      parser->parseRSquare() || parser->parseComma())
    return failure();
  if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
      parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                     getTagMapAttrName(), result->attributes) ||
      parser->parseRSquare() || parser->parseComma())
    return failure();
  if (parser->parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride; either none or exactly two
  // trailing operands are accepted.
  if (parser->parseTrailingOperandList(strideInfo))
    return failure();
  const bool isStrided = strideInfo.size() == 2;
  if (!isStrided && !strideInfo.empty())
    return parser->emitError(parser->getNameLoc(),
                             "expected two stride related operands");

  // Exactly three types are expected: src, dst and tag memref types.
  if (parser->parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser->emitError(parser->getNameLoc(), "expected three types");

  // Resolve operands in operand-list order; everything that is not a memref
  // is resolved as 'index'.
  if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
      parser->resolveOperands(srcMapOperands, indexType, result->operands))
    return failure();
  if (parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
      parser->resolveOperands(dstMapOperands, indexType, result->operands))
    return failure();
  if (parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
      parser->resolveOperands(tagMapOperands, indexType, result->operands))
    return failure();
  if (parser->resolveOperand(numElementsInfo, indexType, result->operands))
    return failure();
  if (isStrided &&
      parser->resolveOperands(strideInfo, indexType, result->operands))
    return failure();

  // Check that src/dst/tag operand counts match their map.numInputs.
  const bool srcCountOk =
      srcMapOperands.size() == srcMapAttr.getValue().getNumInputs();
  const bool dstCountOk =
      srcCountOk &&
      dstMapOperands.size() == dstMapAttr.getValue().getNumInputs();
  const bool tagCountOk =
      dstCountOk &&
      tagMapOperands.size() == tagMapAttr.getValue().getNumInputs();
  if (!tagCountOk)
    return parser->emitError(parser->getNameLoc(),
                             "memref operand count not equal to map.numInputs");
  return success();
}
// Verifies the structural invariants the accessors rely on. The checks are
// ordered so that nothing dereferences a map attribute or indexes an operand
// before its presence has been established.
LogicalResult AffineDmaStartOp::verify() {
  // Every operand position is derived from the three affine maps, so they
  // must exist (and be affine maps) before any accessor is used.
  if (!getAttrOfType<AffineMapAttr>(getSrcMapAttrName()) ||
      !getAttrOfType<AffineMapAttr>(getDstMapAttrName()) ||
      !getAttrOfType<AffineMapAttr>(getTagMapAttrName()))
    return emitOpError(
        "requires 'src_map', 'dst_map' and 'tag_map' affine map attributes");

  // Operands: 3 memrefs + all map operands + num elements, optionally
  // followed by stride and number of elements per stride.
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  unsigned numNonStridedOperands = numInputsAllMaps + 3 + 1;
  if (getNumOperands() != numNonStridedOperands &&
      getNumOperands() != numNonStridedOperands + 2) {
    return emitOpError("incorrect number of operands");
  }

  if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // Only DMAs between different memory spaces are supported.
  if (getSrcMemorySpace() == getDstMemorySpace()) {
    return emitOpError("DMA should be between different memory spaces");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// AffineDmaWaitOp
//===----------------------------------------------------------------------===//
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
// Attaches 'tagMap' as the tag map attribute and appends the operands in the
// order: tag memref, tag indices, num elements.
void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
                            Value *tagMemRef, AffineMap tagMap,
                            ArrayRef<Value *> tagIndices, Value *numElements) {
  auto tagMapAttr = builder->getAffineMapAttr(tagMap);
  result->addAttribute(getTagMapAttrName(), tagMapAttr);

  result->addOperands(tagMemRef);
  result->addOperands(tagIndices);
  result->addOperands(numElements);
}
// Prints the custom form:
//   affine.dma_wait %tag[...], %num_elements : <tag type>
void AffineDmaWaitOp::print(OpAsmPrinter *p) {
  auto tagIndices = getTagIndices();
  SmallVector<Value *, 2> tagMapOperands(tagIndices.begin(), tagIndices.end());

  *p << "affine.dma_wait " << *getTagMemRef() << '[';
  p->printAffineMapOfSSAIds(getTagMapAttr(), tagMapOperands);
  *p << ']' << ", ";
  p->printOperand(getNumElements());
  *p << " : " << getTagMemRef()->getType();
}
// Parse AffineDmaWaitOp.
// Eg:
// affine.dma_wait %tag[%index], %num_elements
// : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
                                   OperationState *result) {
  OpAsmParser::OperandType tagMemRefInfo, numElementsInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser->getBuilder().getIndexType();

  // Parse `%tag[<map of ssa ids>], %num_elements : <type>`. The parsed map is
  // also recorded in result->attributes under the tag map attribute name.
  if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
      parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                     getTagMapAttrName(), result->attributes) ||
      parser->parseRSquare() || parser->parseComma())
    return failure();
  if (parser->parseOperand(numElementsInfo) || parser->parseColonType(type))
    return failure();

  // Resolve operands in operand-list order: the tag memref gets the parsed
  // type, the map operands and the element count are of 'index' type.
  if (parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
      parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
      parser->resolveOperand(numElementsInfo, indexType, result->operands))
    return failure();

  if (!type.isa<MemRefType>())
    return parser->emitError(parser->getNameLoc(),
                             "expected tag to be of memref type");

  // The bracketed index list must supply exactly one operand per map input.
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser->emitError(parser->getNameLoc(),
                             "tag memref operand count != to map.numInputs");
  return success();
}
// Verifies the structural invariants the accessors rely on: the tag map
// attribute is present, the operand list is exactly
// [tag memref, tag map operands..., num elements], and the tag is a memref.
LogicalResult AffineDmaWaitOp::verify() {
  // The operand layout is derived from the tag map, so it must exist (and be
  // an affine map) before getTagMap() is used.
  if (!getAttrOfType<AffineMapAttr>(getTagMapAttrName()))
    return emitOpError("requires a 'tag_map' affine map attribute");

  // Guard getOperand(0) below: the op is variadic, so the operand count is
  // not enforced by a trait.
  if (getNumOperands() != 1 + getTagMap().getNumInputs() + 1)
    return emitOpError("incorrect number of operands");

  if (!getOperand(0)->getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");
  return success();
}
//===----------------------------------------------------------------------===//
// AffineForOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,127 @@
// RUN: mlir-opt %s -split-input-file | FileCheck %s
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
// Test with loop IVs.
func @test0(%arg0 : index, %arg1 : index) {
// Round-trip test: src and dst are indexed directly by the two loop IVs and
// the tag by the constant %c0. %arg0 and %arg1 are not used in this function.
// %0 is the source (no explicit memory space), %1 the destination (memory
// space 2, identity layout map), %2 the one-element tag buffer.
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
// Non-strided DMA of %c64 elements; the type list spells the destination
// type without its identity layout map.
affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// Waits on the same tag element the DMA above was started with.
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
// Test with loop IVs and optional stride arguments.
func @test1(%arg0 : index, %arg1 : index) {
// Round-trip test for the strided form: same accesses as a plain loop-IV DMA,
// followed by the two optional stride operands. %arg0 and %arg1 are not used
// in this function.
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
%c128 = constant 128 : index
%c256 = constant 256 : index
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
// Trailing operands are: num elements (%c64), stride (%c128) and number of
// elements per stride (%c256).
// NOTE(review): the op documentation asks for num_elements to be a multiple
// of number_of_elements_per_stride, which 64 / 256 is not. This only
// exercises parsing and printing, but confirm the values are intentional.
affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1 + d2 + 5)
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0 + d1, d2)
// Test with loop IVs and symbols (without symbol keyword).
func @test2(%arg0 : index, %arg1 : index) {
// Round-trip test: the function arguments appear inside the index
// expressions without the symbol(...) wrapper, next to the loop IVs and a
// constant term.
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
// src is indexed by [%i0 + %arg0, %i1], dst by [%i0, %i1 + %arg1 + 5]; the
// tag is still indexed by the constant %c0.
affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5],
%2[%c0], %c64
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> (d0 + s0, d1)
// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
// Test with loop IVs and symbols (with symbol keyword).
func @test3(%arg0 : index, %arg1 : index) {
// Round-trip test: the function arguments are wrapped in symbol(...) inside
// the src and dst index expressions, and the tag index is itself an affine
// expression of both loop IVs.
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
affine.dma_start %0[%i0 + symbol(%arg0), %i1],
%1[%i0, %i1 + symbol(%arg1) + 7],
%2[%i0 + %i1 + 11], %c64
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// Note that the wait below indexes the tag with %c0, not with the
// %i0 + %i1 + 11 expression used by the dma_start above.
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0 + symbol(%arg0), %i1], %1[%i0, %i1 + symbol(%arg1) + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
// Test with loop IVs, symbols and constants in nested affine expressions.
func @test4(%arg0 : index, %arg1 : index) {
// Round-trip test: index expressions nest floordiv and mod over sums of loop
// IVs and symbol(...) operands, and the wait uses a non-trivial tag index.
// Here the destination is allocated without an explicit layout map.
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, 2>
%2 = alloc() : memref<1xi32>
%c64 = constant 64 : index
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1],
%1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7],
%2[%i0 + %i1 + 11], %c64
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// The wait indexes the tag with the same expression as the dma_start.
affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1], %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
}
}
return
}