forked from OSchip/llvm-project
Add new operations affine.dma_start and affine.dma_wait which take affine maps for indexing memrefs by construction.
These ops are analogues of the current standard ops dma_start/wait, with the exception that the memref operands are affine expressions of loop IVs and symbols (analogous to affine.load/store). The addition of these operations will enable changes to affine transformation and analysis passes which operate on memory dereferencing operations. PiperOrigin-RevId: 255658382
This commit is contained in:
parent
7b17f4e647
commit
6c68596aee
|
@ -90,6 +90,285 @@ public:
|
|||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
|
||||
/// from a source memref to a destination memref. The source and destination
|
||||
/// memref need not be of the same dimensionality, but need to have the same
|
||||
/// elemental type. The operands include the source and destination memref's
|
||||
/// each followed by its indices, size of the data transfer in terms of the
|
||||
/// number of elements (of the elemental type of the memref), a tag memref with
|
||||
/// its indices, and optionally at the end, a stride and a
|
||||
/// number_of_elements_per_stride arguments. The tag location is used by an
|
||||
/// AffineDmaWaitOp to check for completion. The indices of the source memref,
|
||||
/// destination memref, and the tag memref have the same restrictions as any
|
||||
/// affine.load/store. In particular, index for each memref dimension must be an
|
||||
/// affine expression of loop induction variables and symbols.
|
||||
/// The optional stride arguments should be of 'index' type, and specify a
|
||||
/// stride for the slower memory space (memory space with a lower memory space
|
||||
/// id), tranferring chunks of number_of_elements_per_stride every stride until
|
||||
/// %num_elements are transferred. Either both or no stride arguments should be
|
||||
/// specified. The value of 'num_elements' must be a multiple of
|
||||
/// 'number_of_elements_per_stride'.
|
||||
//
|
||||
// For example, a DmaStartOp operation that transfers 256 elements of a memref
|
||||
// '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory
|
||||
// space 1 at indices [%k + 7, %l], would be specified as follows:
|
||||
//
|
||||
// %num_elements = constant 256
|
||||
// %idx = constant 0 : index
|
||||
// %tag = alloc() : memref<1xi32, 4>
|
||||
// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
|
||||
// %num_elements :
|
||||
// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
|
||||
//
|
||||
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
||||
// transfer %num_elt_per_stride elements every %stride elements apart from
|
||||
// memory space 0 until %num_elements are transferred.
|
||||
//
|
||||
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
|
||||
// %stride, %num_elt_per_stride : ...
|
||||
//
|
||||
// TODO(mlir-team): add additional operands to allow source and destination
|
||||
// striding, and multiple stride levels (possibly using AffineMaps to specify
|
||||
// multiple levels of striding).
|
||||
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
|
||||
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
|
||||
OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Value *srcMemRef,
|
||||
AffineMap srcMap, ArrayRef<Value *> srcIndices,
|
||||
Value *destMemRef, AffineMap dstMap,
|
||||
ArrayRef<Value *> destIndices, Value *tagMemRef,
|
||||
AffineMap tagMap, ArrayRef<Value *> tagIndices,
|
||||
Value *numElements, Value *stride = nullptr,
|
||||
Value *elementsPerStride = nullptr);
|
||||
|
||||
/// Returns the operand index of the src memref.
|
||||
unsigned getSrcMemRefOperandIndex() { return 0; }
|
||||
|
||||
/// Returns the source MemRefType for this DMA operation.
|
||||
Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
|
||||
MemRefType getSrcMemRefType() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
|
||||
|
||||
/// Returns the affine map used to access the src memref.
|
||||
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
|
||||
AffineMapAttr getSrcMapAttr() {
|
||||
return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
/// Returns the source memref affine map indices for this DMA operation.
|
||||
operand_range getSrcIndices() {
|
||||
return {operand_begin() + getSrcMemRefOperandIndex() + 1,
|
||||
operand_begin() + getSrcMemRefOperandIndex() + 1 +
|
||||
getSrcMap().getNumInputs()};
|
||||
}
|
||||
|
||||
/// Returns the memory space of the src memref.
|
||||
unsigned getSrcMemorySpace() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
/// Returns the operand index of the dst memref.
|
||||
unsigned getDstMemRefOperandIndex() {
|
||||
return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
|
||||
}
|
||||
|
||||
/// Returns the destination MemRefType for this DMA operations.
|
||||
Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
|
||||
MemRefType getDstMemRefType() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the memory space of the src memref.
|
||||
unsigned getDstMemorySpace() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the dst memref.
|
||||
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
|
||||
AffineMapAttr getDstMapAttr() {
|
||||
return getAttr(getDstMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
/// Returns the destination memref indices for this DMA operation.
|
||||
operand_range getDstIndices() {
|
||||
return {operand_begin() + getDstMemRefOperandIndex() + 1,
|
||||
operand_begin() + getDstMemRefOperandIndex() + 1 +
|
||||
getDstMap().getNumInputs()};
|
||||
}
|
||||
|
||||
/// Returns the operand index of the tag memref.
|
||||
unsigned getTagMemRefOperandIndex() {
|
||||
return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
|
||||
}
|
||||
|
||||
/// Returns the Tag MemRef for this DMA operation.
|
||||
Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
|
||||
MemRefType getTagMemRefType() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the tag MemRefType.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the tag memref.
|
||||
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
|
||||
AffineMapAttr getTagMapAttr() {
|
||||
return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
/// Returns the tag memref indices for this DMA operation.
|
||||
operand_range getTagIndices() {
|
||||
return {operand_begin() + getTagMemRefOperandIndex() + 1,
|
||||
operand_begin() + getTagMemRefOperandIndex() + 1 +
|
||||
getTagMap().getNumInputs()};
|
||||
}
|
||||
|
||||
/// Returns the number of elements being transferred by this DMA operation.
|
||||
Value *getNumElements() {
|
||||
return getOperand(getTagMemRefOperandIndex() + 1 +
|
||||
getTagMap().getNumInputs());
|
||||
}
|
||||
|
||||
/// Returns the AffineMapAttr associated with 'memref'.
|
||||
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||
if (memref == getSrcMemRef())
|
||||
return {Identifier::get(getSrcMapAttrName(), getContext()),
|
||||
getSrcMapAttr()};
|
||||
else if (memref == getDstMemRef())
|
||||
return {Identifier::get(getDstMapAttrName(), getContext()),
|
||||
getDstMapAttr()};
|
||||
assert(memref == getTagMemRef() &&
|
||||
"DmaStartOp expected source, destination or tag memref");
|
||||
return {Identifier::get(getTagMapAttrName(), getContext()),
|
||||
getTagMapAttr()};
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a faster memory space to a slower one.
|
||||
bool isDestMemorySpaceFaster() {
|
||||
return (getSrcMemorySpace() < getDstMemorySpace());
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a slower memory space to a faster one.
|
||||
bool isSrcMemorySpaceFaster() {
|
||||
// Assumes that a lower number is for a slower memory space.
|
||||
return (getDstMemorySpace() < getSrcMemorySpace());
|
||||
}
|
||||
|
||||
/// Given a DMA start operation, returns the operand position of either the
|
||||
/// source or destination memref depending on the one that is at the higher
|
||||
/// level of the memory hierarchy. Asserts failure if neither is true.
|
||||
unsigned getFasterMemPos() {
|
||||
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
|
||||
return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
|
||||
}
|
||||
|
||||
static StringRef getSrcMapAttrName() { return "src_map"; }
|
||||
static StringRef getDstMapAttrName() { return "dst_map"; }
|
||||
static StringRef getTagMapAttrName() { return "tag_map"; }
|
||||
|
||||
static StringRef getOperationName() { return "affine.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
|
||||
/// Returns true if this DMA operation is strided, returns false otherwise.
|
||||
bool isStrided() {
|
||||
return getNumOperands() !=
|
||||
getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
|
||||
}
|
||||
|
||||
/// Returns the stride value for this DMA operation.
|
||||
Value *getStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1 - 1);
|
||||
}
|
||||
|
||||
/// Returns the number of elements to transfer per stride for this DMA op.
|
||||
Value *getNumElementsPerStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1);
|
||||
}
|
||||
};
|
||||
|
||||
/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
|
||||
/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
|
||||
/// an index with the same restrictions as any load/store index. In particular,
|
||||
/// index for each memref dimension must be an affine expression of loop
|
||||
/// induction variables and symbols. %num_elements is the number of elements
|
||||
/// associated with the DMA operation. For example:
|
||||
//
|
||||
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
|
||||
// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
|
||||
// ...
|
||||
// ...
|
||||
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
|
||||
//
|
||||
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
|
||||
OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Value *tagMemRef,
|
||||
AffineMap tagMap, ArrayRef<Value *> tagIndices,
|
||||
Value *numElements);
|
||||
|
||||
static StringRef getOperationName() { return "affine.dma_wait"; }
|
||||
|
||||
// Returns the Tag MemRef associated with the DMA operation being waited on.
|
||||
Value *getTagMemRef() { return getOperand(0); }
|
||||
MemRefType getTagMemRefType() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the tag memref.
|
||||
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
|
||||
AffineMapAttr getTagMapAttr() {
|
||||
return getAttr(getTagMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
operand_range getTagIndices() {
|
||||
return {operand_begin() + 1,
|
||||
operand_begin() + 1 + getTagMap().getNumInputs()};
|
||||
}
|
||||
|
||||
// Returns the rank (number of indices) of the tag memref.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the AffineMapAttr associated with 'memref'.
|
||||
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||
assert(memref == getTagMemRef());
|
||||
return {Identifier::get(getTagMapAttrName(), getContext()),
|
||||
getTagMapAttr()};
|
||||
}
|
||||
|
||||
/// Returns the number of elements transferred in the associated DMA op.
|
||||
Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
|
||||
|
||||
static StringRef getTagMapAttrName() { return "tag_map"; }
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// The "affine.for" operation represents an affine loop nest, defining an SSA
|
||||
/// value for its induction variable. It has one region capturing the loop body.
|
||||
/// The induction variable is represented as a argument of this region. This SSA
|
||||
|
@ -382,10 +661,19 @@ public:
|
|||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
|
||||
|
||||
/// Returns the affine map used to index the memref for this operation.
|
||||
AffineMap getAffineMap() {
|
||||
return getAttrOfType<AffineMapAttr>("map").getValue();
|
||||
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||
AffineMapAttr getAffineMapAttr() {
|
||||
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
/// Returns the AffineMapAttr associated with 'memref'.
|
||||
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||
assert(memref == getMemRef());
|
||||
return {Identifier::get(getMapAttrName(), getContext()),
|
||||
getAffineMapAttr()};
|
||||
}
|
||||
|
||||
static StringRef getMapAttrName() { return "map"; }
|
||||
static StringRef getOperationName() { return "affine.load"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
@ -435,10 +723,19 @@ public:
|
|||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
|
||||
|
||||
/// Returns the affine map used to index the memref for this operation.
|
||||
AffineMap getAffineMap() {
|
||||
return getAttrOfType<AffineMapAttr>("map").getValue();
|
||||
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||
AffineMapAttr getAffineMapAttr() {
|
||||
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
|
||||
}
|
||||
|
||||
/// Returns the AffineMapAttr associated with 'memref'.
|
||||
NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
|
||||
assert(memref == getMemRef());
|
||||
return {Identifier::get(getMapAttrName(), getContext()),
|
||||
getAffineMapAttr()};
|
||||
}
|
||||
|
||||
static StringRef getMapAttrName() { return "map"; }
|
||||
static StringRef getOperationName() { return "affine.store"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
|
|
@ -37,8 +37,8 @@ using llvm::dbgs;
|
|||
|
||||
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineLoadOp,
|
||||
AffineStoreOp, AffineTerminatorOp>();
|
||||
addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineForOp,
|
||||
AffineIfOp, AffineLoadOp, AffineStoreOp, AffineTerminatorOp>();
|
||||
}
|
||||
|
||||
/// A utility function to check if a value is defined at the top level of a
|
||||
|
@ -696,6 +696,220 @@ void AffineApplyOp::getCanonicalizationPatterns(
|
|||
results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineDmaStartOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
|
||||
void AffineDmaStartOp::build(Builder *builder, OperationState *result,
|
||||
Value *srcMemRef, AffineMap srcMap,
|
||||
ArrayRef<Value *> srcIndices, Value *destMemRef,
|
||||
AffineMap dstMap, ArrayRef<Value *> destIndices,
|
||||
Value *tagMemRef, AffineMap tagMap,
|
||||
ArrayRef<Value *> tagIndices, Value *numElements,
|
||||
Value *stride, Value *elementsPerStride) {
|
||||
result->addOperands(srcMemRef);
|
||||
result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
|
||||
result->addOperands(srcIndices);
|
||||
result->addOperands(destMemRef);
|
||||
result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
|
||||
result->addOperands(destIndices);
|
||||
result->addOperands(tagMemRef);
|
||||
result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
|
||||
result->addOperands(tagIndices);
|
||||
result->addOperands(numElements);
|
||||
if (stride) {
|
||||
result->addOperands({stride, elementsPerStride});
|
||||
}
|
||||
}
|
||||
|
||||
void AffineDmaStartOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.dma_start " << *getSrcMemRef() << '[';
|
||||
SmallVector<Value *, 8> operands(getSrcIndices());
|
||||
p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
|
||||
*p << "], " << *getDstMemRef() << '[';
|
||||
operands.assign(getDstIndices().begin(), getDstIndices().end());
|
||||
p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
|
||||
*p << "], " << *getTagMemRef() << '[';
|
||||
operands.assign(getTagIndices().begin(), getTagIndices().end());
|
||||
p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
*p << "], " << *getNumElements();
|
||||
if (isStrided()) {
|
||||
*p << ", " << *getStride();
|
||||
*p << ", " << *getNumElementsPerStride();
|
||||
}
|
||||
*p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
|
||||
<< getTagMemRefType();
|
||||
}
|
||||
|
||||
// Parse AffineDmaStartOp.
|
||||
// Ex:
|
||||
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
|
||||
// %stride, %num_elt_per_stride
|
||||
// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
|
||||
//
|
||||
ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType srcMemRefInfo;
|
||||
AffineMapAttr srcMapAttr;
|
||||
SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
|
||||
OpAsmParser::OperandType dstMemRefInfo;
|
||||
AffineMapAttr dstMapAttr;
|
||||
SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
|
||||
OpAsmParser::OperandType tagMemRefInfo;
|
||||
AffineMapAttr tagMapAttr;
|
||||
SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
|
||||
OpAsmParser::OperandType numElementsInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 2> strideInfo;
|
||||
|
||||
SmallVector<Type, 3> types;
|
||||
auto indexType = parser->getBuilder().getIndexType();
|
||||
|
||||
// Parse and resolve the following list of operands:
|
||||
// *) dst memref followed by its affine maps operands (in square brackets).
|
||||
// *) src memref followed by its affine map operands (in square brackets).
|
||||
// *) tag memref followed by its affine map operands (in square brackets).
|
||||
// *) number of elements transferred by DMA operation.
|
||||
if (parser->parseOperand(srcMemRefInfo) || parser->parseLSquare() ||
|
||||
parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
|
||||
getSrcMapAttrName(), result->attributes) ||
|
||||
parser->parseRSquare() || parser->parseComma() ||
|
||||
parser->parseOperand(dstMemRefInfo) || parser->parseLSquare() ||
|
||||
parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
|
||||
getDstMapAttrName(), result->attributes) ||
|
||||
parser->parseRSquare() || parser->parseComma() ||
|
||||
parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
|
||||
parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
|
||||
getTagMapAttrName(), result->attributes) ||
|
||||
parser->parseRSquare() || parser->parseComma() ||
|
||||
parser->parseOperand(numElementsInfo))
|
||||
return failure();
|
||||
|
||||
// Parse optional stride and elements per stride.
|
||||
if (parser->parseTrailingOperandList(strideInfo)) {
|
||||
return failure();
|
||||
}
|
||||
if (!strideInfo.empty() && strideInfo.size() != 2) {
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected two stride related operands");
|
||||
}
|
||||
bool isStrided = strideInfo.size() == 2;
|
||||
|
||||
if (parser->parseColonTypeList(types))
|
||||
return failure();
|
||||
|
||||
if (types.size() != 3)
|
||||
return parser->emitError(parser->getNameLoc(), "expected three types");
|
||||
|
||||
if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
|
||||
parser->resolveOperands(srcMapOperands, indexType, result->operands) ||
|
||||
parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
|
||||
parser->resolveOperands(dstMapOperands, indexType, result->operands) ||
|
||||
parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
|
||||
parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
|
||||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||
return failure();
|
||||
|
||||
if (isStrided) {
|
||||
if (parser->resolveOperands(strideInfo, indexType, result->operands))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that src/dst/tag operand counts match their map.numInputs.
|
||||
if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
|
||||
dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
|
||||
tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"memref operand count not equal to map.numInputs");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AffineDmaStartOp::verify() {
|
||||
if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA source to be of memref type");
|
||||
if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA destination to be of memref type");
|
||||
if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA tag to be of memref type");
|
||||
|
||||
// DMAs from different memory spaces supported.
|
||||
if (getSrcMemorySpace() == getDstMemorySpace()) {
|
||||
return emitOpError("DMA should be between different memory spaces");
|
||||
}
|
||||
unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
|
||||
getDstMap().getNumInputs() +
|
||||
getTagMap().getNumInputs();
|
||||
if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
|
||||
getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
|
||||
return emitOpError("incorrect number of operands");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineDmaWaitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
|
||||
void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
|
||||
Value *tagMemRef, AffineMap tagMap,
|
||||
ArrayRef<Value *> tagIndices, Value *numElements) {
|
||||
result->addOperands(tagMemRef);
|
||||
result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
|
||||
result->addOperands(tagIndices);
|
||||
result->addOperands(numElements);
|
||||
}
|
||||
|
||||
void AffineDmaWaitOp::print(OpAsmPrinter *p) {
|
||||
*p << "affine.dma_wait " << *getTagMemRef() << '[';
|
||||
SmallVector<Value *, 2> operands(getTagIndices());
|
||||
p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
*p << "], ";
|
||||
p->printOperand(getNumElements());
|
||||
*p << " : " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse AffineDmaWaitOp.
|
||||
// Eg:
|
||||
// affine.dma_wait %tag[%index], %num_elements
|
||||
// : memref<1 x i32, (d0) -> (d0), 4>
|
||||
//
|
||||
ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType tagMemRefInfo;
|
||||
AffineMapAttr tagMapAttr;
|
||||
SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
|
||||
Type type;
|
||||
auto indexType = parser->getBuilder().getIndexType();
|
||||
OpAsmParser::OperandType numElementsInfo;
|
||||
|
||||
// Parse tag memref, its map operands, and dma size.
|
||||
if (parser->parseOperand(tagMemRefInfo) || parser->parseLSquare() ||
|
||||
parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
|
||||
getTagMapAttrName(), result->attributes) ||
|
||||
parser->parseRSquare() || parser->parseComma() ||
|
||||
parser->parseOperand(numElementsInfo) || parser->parseColonType(type) ||
|
||||
parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
|
||||
parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
|
||||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||
return failure();
|
||||
|
||||
if (!type.isa<MemRefType>())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected tag to be of memref type");
|
||||
|
||||
if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"tag memref operand count != to map.numInputs");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AffineDmaWaitOp::verify() {
|
||||
if (!getOperand(0)->getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA tag to be of memref type");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineForOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
// RUN: mlir-opt %s -split-input-file | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
|
||||
|
||||
// Test with loop IVs.
|
||||
func @test0(%arg0 : index, %arg1 : index) {
|
||||
%0 = alloc() : memref<100x100xf32>
|
||||
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
|
||||
%2 = alloc() : memref<1xi32>
|
||||
%c0 = constant 0 : index
|
||||
%c64 = constant 64 : index
|
||||
affine.for %i0 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64
|
||||
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
|
||||
|
||||
// Test with loop IVs and optional stride arguments.
|
||||
func @test1(%arg0 : index, %arg1 : index) {
|
||||
%0 = alloc() : memref<100x100xf32>
|
||||
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
|
||||
%2 = alloc() : memref<1xi32>
|
||||
%c0 = constant 0 : index
|
||||
%c64 = constant 64 : index
|
||||
%c128 = constant 128 : index
|
||||
%c256 = constant 256 : index
|
||||
affine.for %i0 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256
|
||||
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
// CHECK: affine.dma_start %0[%i0, %i1], %1[%i0, %i1], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1 + d2 + 5)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0 + d1, d2)
|
||||
|
||||
// Test with loop IVs and symbols (without symbol keyword).
|
||||
func @test2(%arg0 : index, %arg1 : index) {
|
||||
%0 = alloc() : memref<100x100xf32>
|
||||
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
|
||||
%2 = alloc() : memref<1xi32>
|
||||
%c0 = constant 0 : index
|
||||
%c64 = constant 64 : index
|
||||
affine.for %i0 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5],
|
||||
%2[%c0], %c64
|
||||
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
// CHECK: affine.dma_start %0[%i0 + %arg0, %i1], %1[%i0, %i1 + %arg1 + 5], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> (d0 + s0, d1)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
|
||||
|
||||
// Test with loop IVs and symbols (with symbol keyword).
|
||||
func @test3(%arg0 : index, %arg1 : index) {
|
||||
%0 = alloc() : memref<100x100xf32>
|
||||
%1 = alloc() : memref<100x100xf32, (d0, d1) -> (d0, d1), 2>
|
||||
%2 = alloc() : memref<1xi32>
|
||||
%c0 = constant 0 : index
|
||||
%c64 = constant 64 : index
|
||||
affine.for %i0 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.dma_start %0[%i0 + symbol(%arg0), %i1],
|
||||
%1[%i0, %i1 + symbol(%arg1) + 7],
|
||||
%2[%i0 + %i1 + 11], %c64
|
||||
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
// CHECK: affine.dma_start %0[%i0 + symbol(%arg0), %i1], %1[%i0, %i1 + symbol(%arg1) + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
// CHECK: affine.dma_wait %2[%c0], %c64 : memref<1xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
|
||||
// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1) -> (d0 + d1 + 11)
|
||||
|
||||
// Test with loop IVs, symbols and constants in nested affine expressions.
|
||||
func @test4(%arg0 : index, %arg1 : index) {
|
||||
%0 = alloc() : memref<100x100xf32>
|
||||
%1 = alloc() : memref<100x100xf32, 2>
|
||||
%2 = alloc() : memref<1xi32>
|
||||
%c64 = constant 64 : index
|
||||
affine.for %i0 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1],
|
||||
%1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7],
|
||||
%2[%i0 + %i1 + 11], %c64
|
||||
: memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
|
||||
// CHECK: affine.dma_start %0[(%i0 + symbol(%arg0)) floordiv 3, %i1], %1[%i0, (%i1 + symbol(%arg1)) mod 9 + 7], %2[%i0 + %i1 + 11], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
|
||||
// CHECK: affine.dma_wait %2[%i0 + %i1 + 11], %c64 : memref<1xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue