forked from OSchip/llvm-project
Add new operations affine.dma_start and affine.dma_wait which take affine maps for indexing memrefs by construction.
These ops are analogues of the current standard ops dma_start/wait, with the exception that the memref operands are affine expressions of loop IVs and symbols (analogous to affine.load/store). The addition of these operations will enable changes to affine transformation and analysis passes which operate on memory dereferencing operations. PiperOrigin-RevId: 255658382
This commit is contained in:
parent
7b17f4e647
commit
6c68596aee
|
@ -90,6 +90,285 @@ public:
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
|
||||||
|
/// from a source memref to a destination memref. The source and destination
|
||||||
|
/// memref need not be of the same dimensionality, but need to have the same
|
||||||
|
/// elemental type. The operands include the source and destination memref's
|
||||||
|
/// each followed by its indices, size of the data transfer in terms of the
|
||||||
|
/// number of elements (of the elemental type of the memref), a tag memref with
|
||||||
|
/// its indices, and optionally at the end, a stride and a
|
||||||
|
/// number_of_elements_per_stride arguments. The tag location is used by an
|
||||||
|
/// AffineDmaWaitOp to check for completion. The indices of the source memref,
|
||||||
|
/// destination memref, and the tag memref have the same restrictions as any
|
||||||
|
/// affine.load/store. In particular, index for each memref dimension must be an
|
||||||
|
/// affine expression of loop induction variables and symbols.
|
||||||
|
/// The optional stride arguments should be of 'index' type, and specify a
|
||||||
|
/// stride for the slower memory space (memory space with a lower memory space
|
||||||
|
/// id), transferring chunks of number_of_elements_per_stride every stride until
|
||||||
|
/// %num_elements are transferred. Either both or no stride arguments should be
|
||||||
|
/// specified. The value of 'num_elements' must be a multiple of
|
||||||
|
/// 'number_of_elements_per_stride'.
|
||||||
|
//
|
||||||
|
// For example, a DmaStartOp operation that transfers 256 elements of a memref
|
||||||
|
// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
|
||||||
|
// space 1 at indices [%k + 7, %l], would be specified as follows:
|
||||||
|
//
|
||||||
|
// %num_elements = constant 256 : index
|
||||||
|
// %idx = constant 0 : index
|
||||||
|
// %tag = alloc() : memref<1xi32, 2>
|
||||||
|
// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
|
||||||
|
// %num_elements :
|
||||||
|
// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
|
||||||
|
//
|
||||||
|
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
||||||
|
// transfer %num_elt_per_stride elements every %stride elements apart from
|
||||||
|
// memory space 0 until %num_elements are transferred.
|
||||||
|
//
|
||||||
|
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
|
||||||
|
// %stride, %num_elt_per_stride : ...
|
||||||
|
//
|
||||||
|
// TODO(mlir-team): add additional operands to allow source and destination
|
||||||
|
// striding, and multiple stride levels (possibly using AffineMaps to specify
|
||||||
|
// multiple levels of striding).
|
||||||
|
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
|
||||||
|
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
                                   OpTrait::ZeroResult> {
public:
  using Op::Op;

  // Operand layout assumed by every accessor below:
  //   [src memref, <src map inputs>, dst memref, <dst map inputs>,
  //    tag memref, <tag map inputs>, num elements,
  //    (stride, num elements per stride)?]
  // The number of inputs of each group is read from the corresponding
  // 'src_map' / 'dst_map' / 'tag_map' attribute.

  /// Builds an affine.dma_start op. 'stride' and 'elementsPerStride' are
  /// optional trailing operands.
  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
                    AffineMap srcMap, ArrayRef<Value *> srcIndices,
                    Value *destMemRef, AffineMap dstMap,
                    ArrayRef<Value *> destIndices, Value *tagMemRef,
                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
                    Value *numElements, Value *stride = nullptr,
                    Value *elementsPerStride = nullptr);

  /// Returns the operand index of the src memref.
  unsigned getSrcMemRefOperandIndex() { return 0; }

  /// Returns the source memref operand of this DMA operation.
  Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }

  /// Returns the source MemRefType for this DMA operation.
  MemRefType getSrcMemRefType() {
    return getSrcMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the source MemRefType.
  unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }

  /// Returns the affine map used to access the src memref.
  AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
  AffineMapAttr getSrcMapAttr() {
    return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the source memref affine map indices for this DMA operation.
  operand_range getSrcIndices() {
    auto first = operand_begin() + getSrcMemRefOperandIndex() + 1;
    return {first, first + getSrcMap().getNumInputs()};
  }

  /// Returns the memory space of the src memref.
  unsigned getSrcMemorySpace() { return getSrcMemRefType().getMemorySpace(); }

  /// Returns the operand index of the dst memref.
  unsigned getDstMemRefOperandIndex() {
    return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
  }

  /// Returns the destination memref operand of this DMA operation.
  Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }

  /// Returns the destination MemRefType for this DMA operation.
  MemRefType getDstMemRefType() {
    return getDstMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the destination MemRefType.
  unsigned getDstMemRefRank() { return getDstMemRefType().getRank(); }

  /// Returns the memory space of the dst memref.
  unsigned getDstMemorySpace() { return getDstMemRefType().getMemorySpace(); }

  /// Returns the affine map used to access the dst memref.
  AffineMap getDstMap() { return getDstMapAttr().getValue(); }
  AffineMapAttr getDstMapAttr() {
    return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the destination memref indices for this DMA operation.
  operand_range getDstIndices() {
    auto first = operand_begin() + getDstMemRefOperandIndex() + 1;
    return {first, first + getDstMap().getNumInputs()};
  }

  /// Returns the operand index of the tag memref.
  unsigned getTagMemRefOperandIndex() {
    return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
  }

  /// Returns the tag memref operand of this DMA operation.
  Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }

  /// Returns the tag MemRefType for this DMA operation.
  MemRefType getTagMemRefType() {
    return getTagMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the rank (number of indices) of the tag MemRefType.
  unsigned getTagMemRefRank() { return getTagMemRefType().getRank(); }

  /// Returns the affine map used to access the tag memref.
  AffineMap getTagMap() { return getTagMapAttr().getValue(); }
  AffineMapAttr getTagMapAttr() {
    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    auto first = operand_begin() + getTagMemRefOperandIndex() + 1;
    return {first, first + getTagMap().getNumInputs()};
  }

  /// Returns the number of elements being transferred by this DMA operation.
  /// This is the operand that directly follows the tag map inputs.
  Value *getNumElements() {
    return getOperand(getTagMemRefOperandIndex() + 1 +
                      getTagMap().getNumInputs());
  }

  /// Returns the (name, AffineMapAttr) pair associated with 'memref', which
  /// must be the source, destination or tag memref of this op. The memrefs are
  /// compared in that order, so the first match wins.
  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
    if (memref == getSrcMemRef())
      return {Identifier::get(getSrcMapAttrName(), getContext()),
              getSrcMapAttr()};
    if (memref == getDstMemRef())
      return {Identifier::get(getDstMapAttrName(), getContext()),
              getDstMapAttr()};
    assert(memref == getTagMemRef() &&
           "DmaStartOp expected source, destination or tag memref");
    return {Identifier::get(getTagMapAttrName(), getContext()),
            getTagMapAttr()};
  }

  /// Returns true if the destination memory space is the faster one, i.e. this
  /// is a DMA from a slower memory space to a faster one.
  bool isDestMemorySpaceFaster() {
    // Assumes that a lower number is for a slower memory space.
    return getSrcMemorySpace() < getDstMemorySpace();
  }

  /// Returns true if the source memory space is the faster one, i.e. this is a
  /// DMA from a faster memory space to a slower one.
  bool isSrcMemorySpaceFaster() {
    // Assumes that a lower number is for a slower memory space.
    return getDstMemorySpace() < getSrcMemorySpace();
  }

  /// Given a DMA start operation, returns the operand position of either the
  /// source or destination memref depending on the one that is at the higher
  /// level of the memory hierarchy. Asserts failure if neither is true.
  unsigned getFasterMemPos() {
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
    return isSrcMemorySpaceFaster() ? getSrcMemRefOperandIndex()
                                    : getDstMemRefOperandIndex();
  }

  /// Names of the attributes holding the src/dst/tag affine maps.
  static StringRef getSrcMapAttrName() { return "src_map"; }
  static StringRef getDstMapAttrName() { return "dst_map"; }
  static StringRef getTagMapAttrName() { return "tag_map"; }

  static StringRef getOperationName() { return "affine.dma_start"; }

  // Hooks for the custom assembly form and verification.
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();

  /// Returns true if this DMA operation is strided, returns false otherwise.
  /// The op counts as strided whenever operands follow 'num_elements'.
  bool isStrided() {
    return getNumOperands() !=
           getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
  }

  /// Returns the stride value for this DMA operation, or nullptr if the op is
  /// not strided. The stride is the second-to-last operand.
  Value *getStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1 - 1);
  }

  /// Returns the number of elements to transfer per stride for this DMA op,
  /// or nullptr if the op is not strided. This is the last operand.
  Value *getNumElementsPerStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1);
  }
};
|
||||||
|
|
||||||
|
/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
|
||||||
|
/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
|
||||||
|
/// an index with the same restrictions as any load/store index. In particular,
|
||||||
|
/// index for each memref dimension must be an affine expression of loop
|
||||||
|
/// induction variables and symbols. %num_elements is the number of elements
|
||||||
|
/// associated with the DMA operation. For example:
|
||||||
|
//
|
||||||
|
// affine.dma_start %src[%i], %dst[%j], %tag[%index], %num_elements :
|
||||||
|
// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
|
||||||
|
// ...
|
||||||
|
// ...
|
||||||
|
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
|
||||||
|
//
|
||||||
|
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
                                  OpTrait::ZeroResult> {
public:
  using Op::Op;

  // Operand layout: [tag memref, <tag map inputs>, num elements]. The number
  // of map inputs is read from the 'tag_map' attribute.

  /// Builds an affine.dma_wait op on the tag element 'tagMemRef[tagMap(...)]'.
  static void build(Builder *builder, OperationState *result, Value *tagMemRef,
                    AffineMap tagMap, ArrayRef<Value *> tagIndices,
                    Value *numElements);

  static StringRef getOperationName() { return "affine.dma_wait"; }

  /// Returns the tag memref associated with the DMA operation being waited
  /// on. It is always the first operand.
  Value *getTagMemRef() { return getOperand(0); }

  /// Returns the type of the tag memref.
  MemRefType getTagMemRefType() {
    return getTagMemRef()->getType().cast<MemRefType>();
  }

  /// Returns the affine map used to access the tag memref.
  AffineMap getTagMap() { return getTagMapAttr().getValue(); }
  AffineMapAttr getTagMapAttr() {
    return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
  }

  /// Returns the operands fed to the tag affine map.
  operand_range getTagIndices() {
    auto first = operand_begin() + 1;
    return {first, first + getTagMap().getNumInputs()};
  }

  /// Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() { return getTagMemRefType().getRank(); }

  /// Returns the (name, AffineMapAttr) pair associated with 'memref', which
  /// must be the tag memref of this op.
  NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
    assert(memref == getTagMemRef());
    Identifier tagMapId = Identifier::get(getTagMapAttrName(), getContext());
    return {tagMapId, getTagMapAttr()};
  }

  /// Returns the number of elements transferred in the associated DMA op. It
  /// is the operand that directly follows the tag map inputs.
  Value *getNumElements() {
    unsigned numTagMapInputs = getTagMap().getNumInputs();
    return getOperand(1 + numTagMapInputs);
  }

  /// Name of the attribute holding the tag affine map.
  static StringRef getTagMapAttrName() { return "tag_map"; }

  // Hooks for the custom assembly form and verification.
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
|
||||||
|
|
||||||
/// The "affine.for" operation represents an affine loop nest, defining an SSA
|
/// The "affine.for" operation represents an affine loop nest, defining an SSA
|
||||||
/// value for its induction variable. It has one region capturing the loop body.
|
/// value for its induction variable. It has one region capturing the loop body.
|
||||||
/// The induction variable is represented as a argument of this region. This SSA
|
/// The induction variable is represented as a argument of this region. This SSA
|
||||||
|
@ -382,10 +661,19 @@ public:
|
||||||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
|
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
|
||||||
|
|
||||||
/// Returns the affine map used to index the memref for this operation.
|
/// Returns the affine map used to index the memref for this operation.
|
||||||
AffineMap getAffineMap() {
|
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||||
return getAttrOfType<AffineMapAttr>("map").getValue();
|
AffineMapAttr getAffineMapAttr() {
|
||||||
|
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the AffineMapAttr associated with 'memref'.
|
||||||
|
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||||
|
assert(memref == getMemRef());
|
||||||
|
return {Identifier::get(getMapAttrName(), getContext()),
|
||||||
|
getAffineMapAttr()};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Name of the attribute under which this op stores its affine map.
static StringRef getMapAttrName() { return "map"; }
|
||||||
static StringRef getOperationName() { return "affine.load"; }
|
static StringRef getOperationName() { return "affine.load"; }
|
||||||
|
|
||||||
// Hooks to customize behavior of this op.
|
// Hooks to customize behavior of this op.
|
||||||
|
@ -435,10 +723,19 @@ public:
|
||||||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
|
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
|
||||||
|
|
||||||
/// Returns the affine map used to index the memref for this operation.
|
/// Returns the affine map used to index the memref for this operation.
|
||||||
AffineMap getAffineMap() {
|
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||||
return getAttrOfType<AffineMapAttr>("map").getValue();
|
AffineMapAttr getAffineMapAttr() {
|
||||||
|
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the AffineMapAttr associated with 'memref'.
|
||||||
|
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||||
|
assert(memref == getMemRef());
|
||||||
|
return {Identifier::get(getMapAttrName(), getContext()),
|
||||||
|
getAffineMapAttr()};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Name of the attribute under which this op stores its affine map.
static StringRef getMapAttrName() { return "map"; }
|
||||||
static StringRef getOperationName() { return "affine.store"; }
|
static StringRef getOperationName() { return "affine.store"; }
|
||||||
|
|
||||||
// Hooks to customize behavior of this op.
|
// Hooks to customize behavior of this op.
|
||||||
|
|
|
@ -37,8 +37,8 @@ using llvm::dbgs;
|
||||||
|
|
||||||
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
|
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineLoadOp,
|
addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineForOp,
|
||||||
AffineStoreOp, AffineTerminatorOp>();
|
AffineIfOp, AffineLoadOp, AffineStoreOp, AffineTerminatorOp>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A utility function to check if a value is defined at the top level of a
|
/// A utility function to check if a value is defined at the top level of a
|
||||||
|
@ -696,6 +696,220 @@ void AffineApplyOp::getCanonicalizationPatterns(
|
||||||
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
|
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AffineDmaStartOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
|
||||||
|
/// Populates 'result' with the operands and attributes of an affine.dma_start
/// op. Operands are appended in the order the accessors expect: src memref and
/// its map inputs, dst memref and its map inputs, tag memref and its map
/// inputs, the element count, and finally the optional stride pair. The three
/// affine maps are attached as the 'src_map', 'dst_map' and 'tag_map'
/// attributes.
void AffineDmaStartOp::build(Builder *builder, OperationState *result,
                             Value *srcMemRef, AffineMap srcMap,
                             ArrayRef<Value *> srcIndices, Value *destMemRef,
                             AffineMap dstMap, ArrayRef<Value *> destIndices,
                             Value *tagMemRef, AffineMap tagMap,
                             ArrayRef<Value *> tagIndices, Value *numElements,
                             Value *stride, Value *elementsPerStride) {
  // The stride operands are optional but only meaningful as a pair: a lone
  // 'stride' would push a null operand, a lone 'elementsPerStride' would be
  // dropped without notice.
  assert((stride == nullptr) == (elementsPerStride == nullptr) &&
         "expected either both or no stride operands");

  result->addOperands(srcMemRef);
  result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
  result->addOperands(srcIndices);

  result->addOperands(destMemRef);
  result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
  result->addOperands(destIndices);

  result->addOperands(tagMemRef);
  result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
  result->addOperands(tagIndices);

  result->addOperands(numElements);
  if (stride)
    result->addOperands({stride, elementsPerStride});
}
|
||||||
|
|
||||||
|
/// Prints the custom assembly form:
///   affine.dma_start src[...], dst[...], tag[...], num_elements
///       (, stride, num_elt_per_stride)? : src-type, dst-type, tag-type
void AffineDmaStartOp::print(OpAsmPrinter *p) {
  *p << "affine.dma_start ";

  // Source memref and its subscript expressions.
  SmallVector<Value *, 8> srcOperands(getSrcIndices());
  *p << *getSrcMemRef() << '[';
  p->printAffineMapOfSSAIds(getSrcMapAttr(), srcOperands);
  *p << "], ";

  // Destination memref and its subscript expressions.
  auto dstIndices = getDstIndices();
  SmallVector<Value *, 8> dstOperands(dstIndices.begin(), dstIndices.end());
  *p << *getDstMemRef() << '[';
  p->printAffineMapOfSSAIds(getDstMapAttr(), dstOperands);
  *p << "], ";

  // Tag memref and its subscript expressions.
  auto tagIndices = getTagIndices();
  SmallVector<Value *, 8> tagOperands(tagIndices.begin(), tagIndices.end());
  *p << *getTagMemRef() << '[';
  p->printAffineMapOfSSAIds(getTagMapAttr(), tagOperands);
  *p << "], ";

  // Transfer size, followed by the optional stride pair.
  *p << *getNumElements();
  if (isStrided())
    *p << ", " << *getStride() << ", " << *getNumElementsPerStride();

  // Trailing type list.
  *p << " : " << getSrcMemRefType();
  *p << ", " << getDstMemRefType();
  *p << ", " << getTagMemRefType();
}
|
||||||
|
|
||||||
|
// Parse AffineDmaStartOp.
|
||||||
|
// Ex:
|
||||||
|
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
|
||||||
|
// %stride, %num_elt_per_stride
|
||||||
|
// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
|
||||||
|
//
|
||||||
|
ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
|
||||||
|
OperationState *result) {
|
||||||
|
OpAsmParser::OperandType srcMemRefInfo;
|
||||||
|
AffineMapAttr srcMapAttr;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
|
||||||
|
OpAsmParser::OperandType dstMemRefInfo;
|
||||||
|
AffineMapAttr dstMapAttr;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
|
||||||
|
OpAsmParser::OperandType tagMemRefInfo;
|
||||||
|
AffineMapAttr tagMapAttr;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
|
||||||
|
OpAsmParser::OperandType numElementsInfo;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 2> strideInfo;
|
||||||
|
|
||||||
|
SmallVector<Type, 3> types;
|
||||||
|
auto indexType = parser->getBuilder().getIndexType();
|
||||||
|
|
||||||
|
// Parse and resolve the following list of operands:
|
||||||
|
// *) dst memref followed by its affine maps operands (in square brackets).
|
||||||
|
// *) src memref followed by its affine map operands (in square brackets).
|
||||||
|
// *) tag memref followed by its affine map operands (in square brackets).
|
||||||
|
// *) number of elements transferred by DMA operation.
|
||||||
|
if (parser->parseOperand(srcMemRefInfo) || parser->parseLSquare() ||
|
||||||
|
parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
|
||||||
|
getSrcMapAttrName(), result->attributes) ||
|
||||||
|
parser->parseRSquare() || parser->parseComma() ||
|
||||||
|
parser->parseOperand(dstMemRefInfo) || parser->parseLSquare() ||
|
||||||
|
parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
|
||||||
|
getDstMapAttrName(), result->attributes) ||
|
||||||
|
parser->parseRSquare() || parser->parseComma() ||
|
||||||
|
parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
|
||||||
|
parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
|
||||||
|
getTagMapAttrName(), result->attributes) ||
|
||||||
|
parser->parseRSquare() || parser->parseComma() ||
|
||||||
|
parser->parseOperand(numElementsInfo))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Parse optional stride and elements per stride.
|
||||||
|
if (parser->parseTrailingOperandList(strideInfo)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!strideInfo.empty() && strideInfo.size() != 2) {
|
||||||
|
return parser->emitError(parser->getNameLoc(),
|
||||||
|
"expected two stride related operands");
|
||||||
|
}
|
||||||
|
bool isStrided = strideInfo.size() == 2;
|
||||||
|
|
||||||
|
if (parser->parseColonTypeList(types))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (types.size() != 3)
|
||||||
|
return parser->emitError(parser->getNameLoc(), "expected three types");
|
||||||
|
|
||||||
|
if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
|
||||||
|
parser->resolveOperands(srcMapOperands, indexType, result->operands) ||
|
||||||
|
parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
|
||||||
|
parser->resolveOperands(dstMapOperands, indexType, result->operands) ||
|
||||||
|
parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
|
||||||
|
parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
|
||||||
|
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (isStrided) {
|
||||||
|
if (parser->resolveOperands(strideInfo, indexType, result->operands))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that src/dst/tag operand counts match their map.numInputs.
|
||||||
|
if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
|
||||||
|
dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
|
||||||
|
tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
|
||||||
|
return parser->emitError(parser->getNameLoc(),
|
||||||
|
"memref operand count not equal to map.numInputs");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies the structural invariants of an affine.dma_start op: the three
/// map attributes are present, the operand count matches the maps (with or
/// without the stride pair), the src/dst/tag operands are memrefs, and the
/// source and destination live in different memory spaces.
LogicalResult AffineDmaStartOp::verify() {
  // Every operand accessor locates its operand through the map attributes, so
  // make sure they exist before anything dereferences them.
  if (!getAttrOfType<AffineMapAttr>(getSrcMapAttrName()) ||
      !getAttrOfType<AffineMapAttr>(getDstMapAttrName()) ||
      !getAttrOfType<AffineMapAttr>(getTagMapAttrName()))
    return emitOpError(
        "requires 'src_map', 'dst_map' and 'tag_map' affine map attributes");

  // Check the operand count before indexing into the operand list: three
  // memrefs, all map inputs, the element count, and optionally the two stride
  // operands.
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  unsigned numOperands = getNumOperands();
  if (numOperands != numInputsAllMaps + 3 + 1 &&
      numOperands != numInputsAllMaps + 3 + 1 + 2)
    return emitOpError("incorrect number of operands");

  if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // Only DMAs between different memory spaces are supported.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  return success();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AffineDmaWaitOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
|
||||||
|
/// Populates 'result' with the tag map attribute and the operands of an
/// affine.dma_wait op, in the order [tag memref, tag map inputs, num elements].
void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
                            Value *tagMemRef, AffineMap tagMap,
                            ArrayRef<Value *> tagIndices, Value *numElements) {
  // The attribute list is independent of the operand list, so attach the map
  // first and then append the operands in their final order.
  auto tagMapAttr = builder->getAffineMapAttr(tagMap);
  result->addAttribute(getTagMapAttrName(), tagMapAttr);

  result->addOperands(tagMemRef);
  result->addOperands(tagIndices);
  result->addOperands(numElements);
}
|
||||||
|
|
||||||
|
/// Prints the custom assembly form:
///   affine.dma_wait tag[...], num_elements : tag-type
void AffineDmaWaitOp::print(OpAsmPrinter *p) {
  Value *tagMemRef = getTagMemRef();
  SmallVector<Value *, 2> tagOperands(getTagIndices());

  // Tag memref with its subscript expressions.
  *p << "affine.dma_wait " << *tagMemRef << '[';
  p->printAffineMapOfSSAIds(getTagMapAttr(), tagOperands);
  *p << "], ";

  // Element count, then the tag memref type.
  p->printOperand(getNumElements());
  *p << " : " << tagMemRef->getType();
}
|
||||||
|
|
||||||
|
// Parse AffineDmaWaitOp.
|
||||||
|
// Eg:
|
||||||
|
// affine.dma_wait %tag[%index], %num_elements
|
||||||
|
// : memref<1 x i32, (d0) -> (d0), 4>
|
||||||
|
//
|
||||||
|
ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
|
||||||
|
OperationState *result) {
|
||||||
|
OpAsmParser::OperandType tagMemRefInfo;
|
||||||
|
AffineMapAttr tagMapAttr;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
|
||||||
|
Type type;
|
||||||
|
auto indexType = parser->getBuilder().getIndexType();
|
||||||
|
OpAsmParser::OperandType numElementsInfo;
|
||||||
|
|
||||||
|
// Parse tag memref, its map operands, and dma size.
|
||||||
|
if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
|
||||||
|
parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
|
||||||
|
getTagMapAttrName(), result->attributes) ||
|
||||||
|
parser->parseRSquare() || parser->parseComma() ||
|
||||||
|
parser->parseOperand(numElementsInfo) || parser->parseColonType(type) ||
|
||||||
|
parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
|
||||||
|
parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
|
||||||
|
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!type.isa<MemRefType>())
|
||||||
|
return parser->emitError(parser->getNameLoc(),
|
||||||
|
"expected tag to be of memref type");
|
||||||
|
|
||||||
|
if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
|
||||||
|
return parser->emitError(parser->getNameLoc(),
|
||||||
|
"tag memref operand count != to map.numInputs");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies the structural invariants of an affine.dma_wait op: the tag map
/// attribute is present, the operand list is exactly [tag memref, tag map
/// inputs, num elements], and the tag operand is a memref.
LogicalResult AffineDmaWaitOp::verify() {
  // The accessors read the tag map to locate operands, so it has to exist.
  if (!getAttrOfType<AffineMapAttr>(getTagMapAttrName()))
    return emitOpError("requires a 'tag_map' affine map attribute");

  // Check the operand count before touching operand 0: one tag memref, the
  // map inputs, and the element count.
  if (getNumOperands() != 1 + getTagMap().getNumInputs() + 1)
    return emitOpError("incorrect number of operands");

  if (!getOperand(0)->getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");
  return success();
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AffineForOp
|
// AffineForOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
// RUN: mlir-opt %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
|
||||||
|
|
||||||
|
// Test with loop IVs.
|
||||||
|
func @test0(%arg0 : index, %arg1 : index) {
  // Source buffer, destination buffer (memory space 2) and DMA tag.
  %0 = alloc() : memref<100x100xf32>
  %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
  %2 = alloc() : memref<1xi32>
  %c0 = constant 0 : index
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      // Both memrefs are subscripted directly by the loop IVs.
      affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64
        : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
      affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
    }
  }
  return
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
|
||||||
|
|
||||||
|
// Test with loop IVs and optional stride arguments.
|
||||||
|
func @test1(%arg0 : index, %arg1 : index) {
  // Source buffer, destination buffer (memory space 2) and DMA tag.
  %0 = alloc() : memref<100x100xf32>
  %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
  %2 = alloc() : memref<1xi32>
  %c0 = constant 0 : index
  %c64 = constant 64 : index
  %c128 = constant 128 : index
  %c256 = constant 256 : index
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      // %c128 and %c256 are the optional stride operands that follow the
      // element count.
      affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256
        : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
      affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
    }
  }
  return
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1 + d2 + 5)
|
||||||
|
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0 + d1, d2)
|
||||||
|
|
||||||
|
// Test with loop IVs and symbols (without symbol keyword).
|
||||||
|
func @test2(%arg0 : index, %arg1 : index) {
  // Source buffer, destination buffer (memory space 2) and DMA tag.
  %0 = alloc() : memref<100x100xf32>
  %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
  %2 = alloc() : memref<1xi32>
  %c0 = constant 0 : index
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      // The function arguments appear in the subscripts as plain SSA ids.
      affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5],
                       %2[%c0], %c64
        : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
      affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
    }
  }
  return
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
|
||||||
|
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> (d0 + s0, d1)
|
||||||
|
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
|
||||||
|
|
||||||
|
// Test with loop IVs and symbols (with symbol keyword).
|
||||||
|
func @test3(%arg0 : index, %arg1 : index) {
  // Source buffer, destination buffer (memory space 2) and DMA tag.
  %0 = alloc() : memref<100x100xf32>
  %1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
  %2 = alloc() : memref<1xi32>
  %c0 = constant 0 : index
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      // The function arguments are wrapped in symbol(...), and the tag is
      // subscripted by an expression of both loop IVs.
      affine.dma_start %0[%i0 + symbol(%arg0), %i1],
                       %1[%i0, %i1 + symbol(%arg1) + 7],
                       %2[%i0 + %i1 + 11], %c64
        : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
      affine.dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[%i0 + symbol(%arg0), %i1], %1[%i0, %i1 + symbol(%arg1) + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
    }
  }
  return
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
|
||||||
|
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
|
||||||
|
// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
|
||||||
|
|
||||||
|
// Test with loop IVs, symbols and constants in nested affine expressions.
|
||||||
|
func @test4(%arg0 : index, %arg1 : index) {
  // Source buffer, destination buffer (memory space 2) and DMA tag.
  %0 = alloc() : memref<100x100xf32>
  %1 = alloc() : memref<100x100xf32, 2>
  %2 = alloc() : memref<1xi32>
  %c64 = constant 64 : index
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      // Subscripts nest floordiv / mod around sums of IVs and symbols; the
      // wait uses the same tag subscript as the start.
      affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1],
                       %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7],
                       %2[%i0 + %i1 + 11], %c64
        : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
      affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
// CHECK: affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1], %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
// CHECK: affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
    }
  }
  return
}
|
Loading…
Reference in New Issue