diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 0aeb98560c2f..91b0d0e7d98b 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -90,6 +90,285 @@ public: MLIRContext *context); }; +/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data +/// from a source memref to a destination memref. The source and destination +/// memref need not be of the same dimensionality, but need to have the same +/// elemental type. The operands include the source and destination memrefs +/// each followed by its indices, size of the data transfer in terms of the +/// number of elements (of the elemental type of the memref), a tag memref with +/// its indices, and optionally at the end, a stride and a +/// number_of_elements_per_stride argument. The tag location is used by an +/// AffineDmaWaitOp to check for completion. The indices of the source memref, +/// destination memref, and the tag memref have the same restrictions as any +/// affine.load/store. In particular, index for each memref dimension must be an +/// affine expression of loop induction variables and symbols. +/// The optional stride arguments should be of 'index' type, and specify a +/// stride for the slower memory space (memory space with a lower memory space +/// id), transferring chunks of number_of_elements_per_stride every stride until +/// %num_elements are transferred. Either both or no stride arguments should be +/// specified. The value of 'num_elements' must be a multiple of +/// 'number_of_elements_per_stride'. 
+// +// For example, an AffineDmaStartOp op that transfers 256 elements of a memref +// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory +// space 1 at indices [%k + 7, %l], would be specified as follows: +// +// %num_elements = constant 256 : index +// %idx = constant 0 : index +// %tag = alloc() : memref<1xi32, 2> +// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], +// %num_elements : +// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> +// +// If %stride and %num_elt_per_stride are specified, the DMA is expected to +// transfer %num_elt_per_stride elements every %stride elements apart from +// memory space 0 until %num_elements are transferred. +// +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, +// %stride, %num_elt_per_stride : ... +// +// TODO(mlir-team): add additional operands to allow source and destination +// striding, and multiple stride levels (possibly using AffineMaps to specify +// multiple levels of striding). +// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. +class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands, + OpTrait::ZeroResult> { +public: + using Op::Op; + + static void build(Builder *builder, OperationState *result, Value *srcMemRef, + AffineMap srcMap, ArrayRef<Value *> srcIndices, + Value *destMemRef, AffineMap dstMap, + ArrayRef<Value *> destIndices, Value *tagMemRef, + AffineMap tagMap, ArrayRef<Value *> tagIndices, + Value *numElements, Value *stride = nullptr, + Value *elementsPerStride = nullptr); + + /// Returns the operand index of the src memref. + unsigned getSrcMemRefOperandIndex() { return 0; } + + /// Returns the source memref (and its MemRefType) for this DMA operation. + Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } + MemRefType getSrcMemRefType() { + return getSrcMemRef()->getType().cast<MemRefType>(); + } + + /// Returns the rank (number of indices) of the source MemRefType. 
+ unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } + + /// Returns the affine map used to access the src memref. + AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } + AffineMapAttr getSrcMapAttr() { + return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>(); + } + + /// Returns the source memref affine map indices for this DMA operation. + operand_range getSrcIndices() { + return {operand_begin() + getSrcMemRefOperandIndex() + 1, + operand_begin() + getSrcMemRefOperandIndex() + 1 + + getSrcMap().getNumInputs()}; + } + + /// Returns the memory space of the src memref. + unsigned getSrcMemorySpace() { + return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace(); + } + + /// Returns the operand index of the dst memref. + unsigned getDstMemRefOperandIndex() { + return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); + } + + /// Returns the destination memref and its MemRefType for this DMA op. + Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } + MemRefType getDstMemRefType() { + return getDstMemRef()->getType().cast<MemRefType>(); + } + + /// Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() { + return getDstMemRef()->getType().cast<MemRefType>().getRank(); + } + + /// Returns the memory space of the dst memref. + unsigned getDstMemorySpace() { + return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace(); + } + + /// Returns the affine map used to access the dst memref. + AffineMap getDstMap() { return getDstMapAttr().getValue(); } + AffineMapAttr getDstMapAttr() { + return getAttr(getDstMapAttrName()).cast<AffineMapAttr>(); + } + + /// Returns the destination memref indices for this DMA operation. + operand_range getDstIndices() { + return {operand_begin() + getDstMemRefOperandIndex() + 1, + operand_begin() + getDstMemRefOperandIndex() + 1 + + getDstMap().getNumInputs()}; + } + + /// Returns the operand index of the tag memref. 
+ unsigned getTagMemRefOperandIndex() { + return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); + } + + /// Returns the tag memref (and its MemRefType) for this DMA operation. + Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast<MemRefType>(); + } + + /// Returns the rank (number of indices) of the tag MemRefType. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast<MemRefType>().getRank(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); + } + + /// Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { + return {operand_begin() + getTagMemRefOperandIndex() + 1, + operand_begin() + getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()}; + } + + /// Returns the number of elements being transferred by this DMA operation. + Value *getNumElements() { + return getOperand(getTagMemRefOperandIndex() + 1 + + getTagMap().getNumInputs()); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + if (memref == getSrcMemRef()) + return {Identifier::get(getSrcMapAttrName(), getContext()), + getSrcMapAttr()}; + if (memref == getDstMemRef()) + return {Identifier::get(getDstMapAttrName(), getContext()), + getDstMapAttr()}; + assert(memref == getTagMemRef() && + "DmaStartOp expected source, destination or tag memref"); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns true if this is a DMA from a slower memory space to a faster one. + bool isDestMemorySpaceFaster() { + return (getSrcMemorySpace() < getDstMemorySpace()); + } + + /// Returns true if this is a DMA from a faster memory space to a slower one. 
+ bool isSrcMemorySpaceFaster() { + // Assumes that a lower number is for a slower memory space. + return (getDstMemorySpace() < getSrcMemorySpace()); + } + + /// Given a DMA start operation, returns the operand position of either the + /// source or destination memref depending on the one that is at the higher + /// level of the memory hierarchy. Asserts failure if neither is true. + unsigned getFasterMemPos() { + assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); + return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); + } + + static StringRef getSrcMapAttrName() { return "src_map"; } + static StringRef getDstMapAttrName() { return "dst_map"; } + static StringRef getTagMapAttrName() { return "tag_map"; } + + static StringRef getOperationName() { return "affine.dma_start"; } + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + + /// Returns true if this DMA operation is strided, returns false otherwise. + bool isStrided() { + return getNumOperands() != + getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; + } + + /// Returns the stride value, or nullptr if this DMA op is not strided. + Value *getStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 2); + } + + /// Returns the number of elements to transfer per stride for this DMA op. + Value *getNumElementsPerStride() { + if (!isStrided()) + return nullptr; + return getOperand(getNumOperands() - 1); + } +}; + +/// AffineDmaWaitOp blocks until the completion of a DMA operation associated +/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be +/// an index with the same restrictions as any load/store index. In particular, +/// index for each memref dimension must be an affine expression of loop +/// induction variables and symbols. %num_elements is the number of elements +/// associated with the DMA operation. 
For example: +// +// affine.dma_start %src[%i], %dst[%k], %tag[%index], %num_elements : +// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> +// ... +// ... +// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> +// +class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands, + OpTrait::ZeroResult> { +public: + using Op::Op; + + static void build(Builder *builder, OperationState *result, Value *tagMemRef, + AffineMap tagMap, ArrayRef<Value *> tagIndices, + Value *numElements); + + static StringRef getOperationName() { return "affine.dma_wait"; } + + /// Returns the tag memref associated with the DMA operation being waited on. + Value *getTagMemRef() { return getOperand(0); } + MemRefType getTagMemRefType() { + return getTagMemRef()->getType().cast<MemRefType>(); + } + + /// Returns the affine map used to access the tag memref. + AffineMap getTagMap() { return getTagMapAttr().getValue(); } + AffineMapAttr getTagMapAttr() { + return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); + } + + /// Returns the tag memref indices for this DMA operation. + operand_range getTagIndices() { + return {operand_begin() + 1, + operand_begin() + 1 + getTagMap().getNumInputs()}; + } + + /// Returns the rank (number of indices) of the tag memref. + unsigned getTagMemRefRank() { + return getTagMemRef()->getType().cast<MemRefType>().getRank(); + } + + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getTagMemRef()); + return {Identifier::get(getTagMapAttrName(), getContext()), + getTagMapAttr()}; + } + + /// Returns the number of elements transferred in the associated DMA op. + Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } + + static StringRef getTagMapAttrName() { return "tag_map"; } + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); +}; + /// The "affine.for" operation represents an affine loop nest, defining an SSA /// value for its induction variable. 
It has one region capturing the loop body. /// The induction variable is represented as a argument of this region. This SSA @@ -382,10 +661,19 @@ public: operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); } /// Returns the affine map used to index the memref for this operation. - AffineMap getAffineMap() { - return getAttrOfType("map").getValue(); + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast<AffineMapAttr>(); } + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } static StringRef getOperationName() { return "affine.load"; } // Hooks to customize behavior of this op. @@ -435,10 +723,19 @@ public: operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); } /// Returns the affine map used to index the memref for this operation. - AffineMap getAffineMap() { - return getAttrOfType("map").getValue(); + AffineMap getAffineMap() { return getAffineMapAttr().getValue(); } + AffineMapAttr getAffineMapAttr() { + return getAttr(getMapAttrName()).cast<AffineMapAttr>(); } + /// Returns the AffineMapAttr associated with 'memref'. + NamedAttribute getAffineMapAttrForMemRef(Value *memref) { + assert(memref == getMemRef()); + return {Identifier::get(getMapAttrName(), getContext()), + getAffineMapAttr()}; + } + + static StringRef getMapAttrName() { return "map"; } static StringRef getOperationName() { return "affine.store"; } // Hooks to customize behavior of this op. 
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index e35847feb963..016ef43a84a1 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -37,8 +37,8 @@ using llvm::dbgs; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addOperations(); + addOperations(); } /// A utility function to check if a value is defined at the top level of a @@ -696,6 +696,220 @@ void AffineApplyOp::getCanonicalizationPatterns( results.push_back(llvm::make_unique(context)); } +//===----------------------------------------------------------------------===// +// AffineDmaStartOp +//===----------------------------------------------------------------------===// + +// TODO(b/133776335) Check that map operands are loop IVs or symbols. +void AffineDmaStartOp::build(Builder *builder, OperationState *result, + Value *srcMemRef, AffineMap srcMap, + ArrayRef<Value *> srcIndices, Value *destMemRef, + AffineMap dstMap, ArrayRef<Value *> destIndices, + Value *tagMemRef, AffineMap tagMap, + ArrayRef<Value *> tagIndices, Value *numElements, + Value *stride, Value *elementsPerStride) { + result->addOperands(srcMemRef); + result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap)); + result->addOperands(srcIndices); + result->addOperands(destMemRef); + result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap)); + result->addOperands(destIndices); + result->addOperands(tagMemRef); + result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result->addOperands(tagIndices); + result->addOperands(numElements); + assert(!stride == !elementsPerStride && "expected both or no strides"); + if (stride) + result->addOperands({stride, elementsPerStride}); +} + +void AffineDmaStartOp::print(OpAsmPrinter *p) { + *p << "affine.dma_start " << *getSrcMemRef() << '['; + SmallVector<Value *, 4> operands(getSrcIndices()); + p->printAffineMapOfSSAIds(getSrcMapAttr(), operands); + *p << "], " << *getDstMemRef() << '['; + operands.assign(getDstIndices().begin(), 
getDstIndices().end()); + p->printAffineMapOfSSAIds(getDstMapAttr(), operands); + *p << "], " << *getTagMemRef() << '['; + operands.assign(getTagIndices().begin(), getTagIndices().end()); + p->printAffineMapOfSSAIds(getTagMapAttr(), operands); + *p << "], " << *getNumElements(); + if (isStrided()) { + *p << ", " << *getStride(); + *p << ", " << *getNumElementsPerStride(); + } + *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " + << getTagMemRefType(); +} + +// Parse AffineDmaStartOp. +// Ex: +// affine.dma_start %src[%i], %dst[%j], %tag[%index], %size, +// %stride, %num_elt_per_stride +// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> +// +ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + AffineMapAttr srcMapAttr; + SmallVector<OpAsmParser::OperandType, 4> srcMapOperands; + OpAsmParser::OperandType dstMemRefInfo; + AffineMapAttr dstMapAttr; + SmallVector<OpAsmParser::OperandType, 4> dstMapOperands; + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; + OpAsmParser::OperandType numElementsInfo; + SmallVector<OpAsmParser::OperandType, 2> strideInfo; + + SmallVector<Type, 3> types; + auto indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) src memref followed by its affine map operands (in square brackets). + // *) dst memref followed by its affine map operands (in square brackets). + // *) tag memref followed by its affine map operands (in square brackets). + // *) number of elements transferred by DMA operation. 
+ if (parser->parseOperand(srcMemRefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, + getSrcMapAttrName(), result->attributes) || + parser->parseRSquare() || parser->parseComma() || + parser->parseOperand(dstMemRefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, + getDstMapAttrName(), result->attributes) || + parser->parseRSquare() || parser->parseComma() || + parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseRSquare() || parser->parseComma() || + parser->parseOperand(numElementsInfo)) + return failure(); + + // Parse optional stride and elements per stride. + if (parser->parseTrailingOperandList(strideInfo)) { + return failure(); + } + if (!strideInfo.empty() && strideInfo.size() != 2) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + bool isStrided = strideInfo.size() == 2; + + if (parser->parseColonTypeList(types)) + return failure(); + + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "expected three types"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcMapOperands, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstMapOperands, indexType, result->operands) || + parser->resolveOperand(tagMemRefInfo, types[2], result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (isStrided) { + if (parser->resolveOperands(strideInfo, indexType, result->operands)) + return failure(); + } + + // Check that src/dst/tag operand counts match their map.numInputs. 
+ if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || + dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || + tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "memref operand count not equal to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaStartOp::verify() { + // Check the operand count first; the accessors below index by map inputs. + unsigned numInputsAllMaps = getSrcMap().getNumInputs() + + getDstMap().getNumInputs() + + getTagMap().getNumInputs(); + if (getNumOperands() != numInputsAllMaps + 3 + 1 && + getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA source to be of memref type"); + if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA destination to be of memref type"); + if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA tag to be of memref type"); + // Only DMAs between different memory spaces are supported. + if (getSrcMemorySpace() == getDstMemorySpace()) { + return emitOpError("DMA should be between different memory spaces"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AffineDmaWaitOp +//===----------------------------------------------------------------------===// + +// TODO(b/133776335) Check that map operands are loop IVs or symbols. 
+void AffineDmaWaitOp::build(Builder *builder, OperationState *result, + Value *tagMemRef, AffineMap tagMap, + ArrayRef<Value *> tagIndices, Value *numElements) { + result->addOperands(tagMemRef); + result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result->addOperands(tagIndices); + result->addOperands(numElements); +} + +void AffineDmaWaitOp::print(OpAsmPrinter *p) { + *p << "affine.dma_wait " << *getTagMemRef() << '['; + SmallVector<Value *, 4> operands(getTagIndices()); + p->printAffineMapOfSSAIds(getTagMapAttr(), operands); + *p << "], "; + p->printOperand(getNumElements()); + *p << " : " << getTagMemRef()->getType(); +} + +// Parse AffineDmaWaitOp. +// Ex: +// affine.dma_wait %tag[%index], %num_elements +// : memref<1 x i32, (d0) -> (d0), 4> +// +ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; + Type type; + auto indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its map operands, and dma size. 
+ if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseRSquare() || parser->parseComma() || + parser->parseOperand(numElementsInfo) || parser->parseColonType(type) || + parser->resolveOperand(tagMemRefInfo, type, result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (!type.isa<MemRefType>()) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "tag memref operand count != to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaWaitOp::verify() { + if (!getOperand(0)->getType().isa<MemRefType>()) + return emitOpError("expected DMA tag to be of memref type"); + return success(); +} + //===----------------------------------------------------------------------===// // AffineForOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/AffineOps/dma.mlir b/mlir/test/AffineOps/dma.mlir new file mode 100644 index 000000000000..68acbe2229cb --- /dev/null +++ b/mlir/test/AffineOps/dma.mlir @@ -0,0 +1,127 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1) + +// Test with loop IVs. 
+func @test0(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2> + %2 = alloc() : memref<1xi32> + %c0 = constant 0 : index + %c64 = constant 64 : index + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64 + : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> + affine.dma_wait %2[%c0], %c64 : memref<1xi32> +// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1) + +// Test with loop IVs and the two optional trailing stride operands. +func @test1(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2> + %2 = alloc() : memref<1xi32> + %c0 = constant 0 : index + %c64 = constant 64 : index + %c128 = constant 128 : index + %c256 = constant 256 : index + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256 + : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> + affine.dma_wait %2[%c0], %c64 : memref<1xi32> +// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1 + d2 + 5) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0 + d1, d2) + +// Test with loop IVs and symbols (no symbol keyword: args map to dims). 
+func @test2(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2> + %2 = alloc() : memref<1xi32> + %c0 = constant 0 : index + %c64 = constant 64 : index + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5], + %2[%c0], %c64 + : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> + affine.dma_wait %2[%c0], %c64 : memref<1xi32> +// CHECK: affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, d1 + s0 + 7) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> (d0 + s0, d1) +// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11) + +// Test with loop IVs and symbols (with symbol keyword). +func @test3(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2> + %2 = alloc() : memref<1xi32> + %c0 = constant 0 : index + %c64 = constant 64 : index + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.dma_start %0[%i0 + symbol(%arg0), %i1], + %1[%i0, %i1 + symbol(%arg1) + 7], + %2[%i0 + %i1 + 11], %c64 + : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> + affine.dma_wait %2[%c0], %c64 : memref<1xi32> +// CHECK: affine.dma_start %0[%i0 + symbol(%arg0), %i1], %1[%i0, %i1 + symbol(%arg1) + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7) +// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1) +// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11) + +// Test with loop IVs, 
symbols and constants in nested affine expressions. +func @test4(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + %1 = alloc() : memref<100x100xf32, 2> + %2 = alloc() : memref<1xi32> + %c64 = constant 64 : index + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1], + %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7], + %2[%i0 + %i1 + 11], %c64 + : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> + affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32> +// CHECK: affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1], %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> +// CHECK: affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32> + } + } + return +}