[mlir][vector] Extend transfer read/write ops to support tensor types.

Transfer ops can now work on both buffers and tensors. Lowering of the tensor
case is not supported yet.

Differential Revision: https://reviews.llvm.org/D93500
Thomas Raoux 2020-12-17 16:26:07 -08:00
parent 9a93f95fce
commit 26c8f9081b
16 changed files with 304 additions and 189 deletions
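
For reference, the new tensor forms mirror the existing memref forms. The sketch below is adapted from the test cases added at the end of this diff; the function and SSA names are illustrative only.

```
func @tensor_transfer(%t: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c3 = constant 3 : index
  %f0 = constant 0.0 : f32
  // Read a vector from a ranked tensor; %f0 pads out-of-bounds elements.
  %v = vector.transfer_read %t[%c3, %c3], %f0
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : tensor<?x?xf32>, vector<128xf32>
  // Write it back; unlike the memref case, the op returns the updated tensor.
  %t1 = vector.transfer_write %v, %t[%c3, %c3]
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : vector<128xf32>, tensor<?x?xf32>
  return %t1 : tensor<?x?xf32>
}
```

As the commit message notes, lowering of these tensor forms is not supported yet.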


@ -126,7 +126,7 @@ namespace impl {
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType);
} // namespace impl
} // end namespace vector


@ -1056,7 +1056,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs AnyVector:$vector)> {
@ -1065,15 +1065,16 @@ def Vector_TransferReadOp :
let description = [{
The `vector.transfer_read` op performs a read from a slice within a
[MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same base elemental type.
[MemRef](../LangRef.md#memref-type) or a Ranked
[Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
[vector](../LangRef.md#vector-type) of the same base elemental type.
A memref operand with vector element type, must have its vector element
type match a suffix (shape and element type) of the vector (e.g.
A memref/tensor operand with vector element type, must have its vector
element type match a suffix (shape and element type) of the vector (e.g.
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `2 .. 1 + rank(memref)`.
The slice is further defined by a full-rank index within the MemRef/Tensor,
supplied as the operands `2 .. 1 + rank(memref/tensor)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@ -1084,8 +1085,9 @@ def Vector_TransferReadOp :
The size of the slice is specified by the size of the vector, given as the
return type.
An `ssa-value` of the same elemental type as the MemRef is provided as the
last operand to specify padding in the case of out-of-bounds accesses.
An `ssa-value` of the same elemental type as the MemRef/Tensor is provided
as the last operand to specify padding in the case of out-of-bounds
accesses.
An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
@ -1196,17 +1198,22 @@ def Vector_TransferReadOp :
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0
{permutation_map = (d0, d1)->(d0, d1)}
: memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// Read from a tensor with vector element type.
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0
{permutation_map = (d0, d1)->(d0, d1)}
: tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}];
let builders = [
// Builder that sets padding to zero.
OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>
];
@ -1217,26 +1224,29 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
]>,
Arguments<(ins AnyVector:$vector, AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
OptionalAttr<BoolArrayAttr>:$masked)> {
OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
let description = [{
The `vector.transfer_write` op performs a write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
slice within a [MemRef](../LangRef.md#memref-type) of the same base
elemental type, supplied as its second operand.
slice within a [MemRef](../LangRef.md#memref-type) or a Ranked
[Tensor](../LangRef.md#tensor-type) of the same base elemental type,
supplied as its second operand.
A vector memref operand must have its vector element type match a suffix
(shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
vector<1x1x4x3xf32>).
A vector memref/tensor operand must have its vector element type match a
suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a
new tensor of the same type.
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `3 .. 2 + rank(memref)`.
The slice is further defined by a full-rank index within the MemRef/Tensor,
supplied as the operands `3 .. 2 + rank(memref/tensor)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@ -1280,15 +1290,24 @@ def Vector_TransferWriteOp :
vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
: vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// return a tensor where the vector is inserted into the source tensor.
%5 = vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
: vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
```
}];
let builders = [
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap)>,
OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$masked)>,
OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, "ArrayAttr":$masked)>,
];
let hasFolder = 1;
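
Because tensors are SSA values, a `vector.transfer_write` on a tensor does not mutate its operand; consumers must use the returned tensor to observe the written data. A minimal sketch of this chaining, with illustrative names and types:

```
func @write_then_read(%t0: tensor<8xf32>, %v: vector<4xf32>) -> vector<4xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  // The write yields a new tensor %t1; the later read consumes %t1, not %t0.
  %t1 = vector.transfer_write %v, %t0[%c0] : vector<4xf32>, tensor<8xf32>
  %w = vector.transfer_read %t1[%c0], %f0 : tensor<8xf32>, vector<4xf32>
  return %w : vector<4xf32>
}
```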


@ -20,9 +20,9 @@ class AffineApplyOp;
class AffineForOp;
class AffineMap;
class Location;
class MemRefType;
class OpBuilder;
class Operation;
class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;
@ -157,7 +157,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType);
/// Return true if we can prove that the transfer operations access disjoint


@ -47,7 +47,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
let description = [{
Encodes properties of an operation on vectors that can be unrolled.
Encodes properties of a transfer read or write operation.
}];
let cppNamespace = "::mlir";
@ -83,11 +83,11 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
}]
>,
InterfaceMethod<
/*desc=*/"Return the memref operand.",
/*desc=*/"Return the memref or ranked tensor operand.",
/*retTy=*/"Value",
/*methodName=*/"memref",
/*methodName=*/"source",
/*args=*/(ins),
/*methodBody=*/"return $_op.memref();"
/*methodBody=*/"return $_op.source();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@ -123,13 +123,13 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*defaultImplementation=*/
>,
InterfaceMethod<
/*desc=*/"Return the MemRefType.",
/*retTy=*/"MemRefType",
/*methodName=*/"getMemRefType",
/*desc=*/"Return the ShapedType.",
/*retTy=*/"ShapedType",
/*methodName=*/"getShapedType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
"return $_op.memref().getType().template cast<MemRefType>();"
"return $_op.source().getType().template cast<ShapedType>();"
>,
InterfaceMethod<
/*desc=*/"Return the VectorType.",
@ -152,14 +152,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
"return $_op.permutation_map().getNumResults();"
>,
InterfaceMethod<
/*desc=*/[{ Return the number of leading memref dimensions that do not
/*desc=*/[{ Return the number of leading shaped dimensions that do not
participate in the permutation map.}],
/*retTy=*/"unsigned",
/*methodName=*/"getLeadingMemRefRank",
/*methodName=*/"getLeadingShapedRank",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
"return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
"return $_op.getShapedType().getRank() - $_op.getTransferRank();"
>,
InterfaceMethod<
/*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
@ -178,8 +178,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*desc=*/[{
Helper function to account for the fact that `permutationMap` results and
`op.indices` sizes may not match and may not be aligned. The first
`getLeadingMemRefRank()` indices may just be indexed and not transferred
from/into the vector.
`getLeadingShapedRank()` indices may just be indexed and not
transferred from/into the vector.
For example:
```
vector.transfer %0[%i, %j, %k, %c0] :
@ -195,7 +195,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
for (int64_t resultIdx = 0,
indicesIdx = $_op.getLeadingMemRefRank(),
indicesIdx = $_op.getLeadingShapedRank(),
eResult = $_op.getTransferRank();
resultIdx < eResult;
++resultIdx, ++indicesIdx)


@ -22,6 +22,17 @@
using namespace mlir;
/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(LoadOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
static Value getMemRefOperand(StoreOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
}
namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
@ -141,7 +152,7 @@ template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
@ -162,7 +173,8 @@ template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}


@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
unsigned &align) {
Type elementTy =
typeConverter.convertType(op.getMemRefType().getElementType());
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
Type elementTy = typeConverter.convertType(memrefType.getElementType());
if (!elementTy)
return failure();
@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
return success();
@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
return failure();
unsigned align;
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
TransferWriteOp xferOp, ArrayRef<Value> operands,
Value dataPtr, Value mask) {
unsigned align;
if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
@ -345,7 +347,8 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
align)))
return failure();
auto vtype = typeConverter->convertType(load.getResultVectorType());
@ -375,7 +378,8 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
align)))
return failure();
auto vtype = typeConverter->convertType(store.getValueVectorType());
@ -405,7 +409,8 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
align)))
return failure();
// Get index ptrs.
@ -438,7 +443,8 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
align)))
return failure();
// Get index ptrs.
@ -1182,8 +1188,11 @@ public:
xferOp.getVectorType().getRank(),
xferOp->getContext()))
return failure();
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
return failure();
// Only contiguous source tensors supported atm.
auto strides = computeContiguousStrides(xferOp.getMemRefType());
auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
@ -1192,10 +1201,9 @@ public:
};
Location loc = xferOp->getLoc();
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
memRefType.getElementType().dyn_cast<VectorType>()) {
memRefType.getElementType().template dyn_cast<VectorType>()) {
// Memref has vector element type.
if (memrefVectorElementType.getElementType() !=
xferOp.getVectorType().getElementType())
@ -1222,7 +1230,7 @@ public:
// address space 0.
// TODO: support alignment when possible.
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
@ -1248,7 +1256,7 @@ public:
unsigned vecWidth = vecTy.getVectorNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
Value mask = buildVectorComparison(
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);


@ -89,7 +89,9 @@ public:
return failure();
// Obtain dataPtr and elementType from the memref.
MemRefType memRefType = xferOp.getMemRefType();
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
return failure();
// MUBUF instruction operate only on addresspace 0(unified) or 1(global)
// In case of 3(LDS): fall back to vector->llvm pass
// In case of 5(VGPR): wrong
@ -101,7 +103,7 @@ public:
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.


@ -107,7 +107,7 @@ public:
// TODO: when we go to k > 1-D vectors adapt minorRank.
minorRank = 1;
majorRank = vectorType.getRank() - minorRank;
leadingRank = xferOp.getLeadingMemRefRank();
leadingRank = xferOp.getLeadingShapedRank();
majorVectorType =
VectorType::get(vectorType.getShape().take_front(majorRank),
vectorType.getElementType());
@ -115,9 +115,9 @@ public:
VectorType::get(vectorType.getShape().take_back(minorRank),
vectorType.getElementType());
/// Memref of minor vector type is used for individual transfers.
memRefMinorVectorType =
MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
xferOp.getMemRefType().getMemorySpace());
memRefMinorVectorType = MemRefType::get(
majorVectorType.getShape(), minorVectorType, {},
xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
}
LogicalResult doReplace();
@ -155,7 +155,7 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(
const MemRefBoundsCapture &)>
loopBodyBuilder) {
/// Loop nest operates on the major dimensions
MemRefBoundsCapture memrefBoundsCapture(xferOp.memref());
MemRefBoundsCapture memrefBoundsCapture(xferOp.source());
if (options.unroll) {
auto shape = majorVectorType.getShape();
@ -272,9 +272,9 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
Value memref = xferOp.memref();
Value memref = xferOp.source();
auto map =
getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
@ -379,13 +379,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
else
result = std_load(alloc, majorIvs);
auto map =
getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({false});
}
vector_transfer_write(result, xferOp.memref(), indexing,
vector_transfer_write(result, xferOp.source(), indexing,
AffineMapAttr::get(map), masked);
};
@ -422,7 +422,7 @@ template <typename TransferOpTy>
static int computeCoalescedIndex(TransferOpTy transfer) {
// rank of the remote memory access, coalescing behavior occurs on the
// innermost memory dimension.
auto remoteRank = transfer.getMemRefType().getRank();
auto remoteRank = transfer.getShapedType().getRank();
// Iterate over the results expressions of the permutation map to determine
// the loop order for creating pointwise copies between remote and local
// memories.
@ -536,13 +536,14 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
using namespace mlir::edsc::op;
TransferReadOp transfer = cast<TransferReadOp>(op);
auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
// Fall back to a loop if the fastest varying stride is not 1 or it is
// permuted.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
getStridesAndOffset(transfer.getMemRefType(), strides, offset);
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (succeeded(successStrides) && strides.back() == 1 &&
transfer.permutation_map().isMinorIdentity()) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
@ -557,8 +558,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// Conservative lowering to scalar load / stores.
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
StdIndexedValue remote(transfer.memref());
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
StdIndexedValue remote(transfer.source());
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
int coalescedIdx = computeCoalescedIndex(transfer);
// Swap the vectorBoundsCapture which will reorder loop bounds.
@ -621,13 +622,15 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
using namespace edsc::op;
TransferWriteOp transfer = cast<TransferWriteOp>(op);
auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
return failure();
// Fall back to a loop if the fastest varying stride is not 1 or it is
// permuted.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
getStridesAndOffset(transfer.getMemRefType(), strides, offset);
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (succeeded(successStrides) && strides.back() == 1 &&
transfer.permutation_map().isMinorIdentity()) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
@ -641,8 +644,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
StdIndexedValue remote(transfer.memref());
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
StdIndexedValue remote(transfer.source());
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
Value vectorValue(transfer.vector());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
int coalescedIdx = computeCoalescedIndex(transfer);


@ -111,7 +111,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
vector::TransferWriteOp transferWrite;
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
if (!candidateWrite || candidateWrite.source() != transferRead.source())
continue;
transferWrite = candidateWrite;
}
@ -142,7 +142,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
for (auto &use : transferRead.memref().getUses()) {
for (auto &use : transferRead.source().getUses()) {
if (!dom.properlyDominates(loop, use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||


@ -411,7 +411,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `view`.
Value viewOrAlloc = xferOp.memref();
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();
@ -487,7 +487,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `viewOrAlloc`.
Value viewOrAlloc = xferOp.memref();
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();


@ -1890,41 +1890,43 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}
static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
VectorType vectorType,
AffineMap permutationMap,
ArrayAttr optionalMasked) {
auto memrefElementType = memrefType.getElementType();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
memrefVectorElementType.getShape().back();
if (!shapedType.isa<MemRefType, RankedTensorType>())
return op->emitOpError(
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
// Memref or tensor has vector element type.
unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
vectorElementType.getShape().back();
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
if (resultVecSize % memrefVecSize != 0)
if (resultVecSize % sourceVecSize != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
"multiple of the bitwidth of the minor 1-D vector of the memref");
"multiple of the bitwidth of the minor 1-D vector of the source");
unsigned memrefVecEltRank = memrefVectorElementType.getRank();
unsigned sourceVecEltRank = vectorElementType.getRank();
unsigned resultVecRank = vectorType.getRank();
if (memrefVecEltRank > resultVecRank)
if (sourceVecEltRank > resultVecRank)
return op->emitOpError(
"requires memref vector element and vector result ranks to match.");
unsigned rankOffset = resultVecRank - memrefVecEltRank;
"requires source vector element and vector result ranks to match.");
unsigned rankOffset = resultVecRank - sourceVecEltRank;
// Check that permutation map results match 'rankOffset' of vector type.
if (permutationMap.getNumResults() != rankOffset)
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
} else {
// Memref has scalar element type.
// Memref or tensor has scalar element type.
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
"multiple of the bitwidth of the memref element type");
"multiple of the bitwidth of the source element type");
// Check that permutation map results match rank of vector type.
if (permutationMap.getNumResults() != vectorType.getRank())
@ -1934,9 +1936,9 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
if (permutationMap.getNumInputs() != memrefType.getRank())
if (permutationMap.getNumInputs() != shapedType.getRank())
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");
"same rank as the source type");
if (optionalMasked) {
if (permutationMap.getNumResults() !=
@ -1978,7 +1980,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 2> elidedAttrs;
if (op.permutation_map() ==
getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
@ -1995,21 +1997,21 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
}
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
p << op.getOperationName() << " " << op.source() << "[" << op.indices()
<< "], " << op.padding();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getMemRefType() << ", " << op.getVectorType();
p << " : " << op.getShapedType() << ", " << op.getVectorType();
}
static ParseResult parseTransferReadOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc typesLoc;
OpAsmParser::OperandType memrefInfo;
OpAsmParser::OperandType sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
OpAsmParser::OperandType paddingInfo;
SmallVector<Type, 2> types;
// Parsing with support for paddingValue.
if (parser.parseOperand(memrefInfo) ||
if (parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseComma() || parser.parseOperand(paddingInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
@ -2018,48 +2020,48 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = parser.getBuilder().getIndexType();
MemRefType memRefType = types[0].dyn_cast<MemRefType>();
if (!memRefType)
return parser.emitError(typesLoc, "requires memref type");
auto shapedType = types[0].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
VectorType vectorType = types[1].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
parser.resolveOperand(paddingInfo, memRefType.getElementType(),
parser.resolveOperand(paddingInfo, shapedType.getElementType(),
result.operands) ||
parser.addTypeToList(vectorType, result.types));
}
static LogicalResult verify(TransferReadOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
// Consistency of elemental types in source and vector.
ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto paddingType = op.padding().getType();
auto permutationMap = op.permutation_map();
auto memrefElementType = memrefType.getElementType();
auto sourceElementType = shapedType.getElementType();
if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
return op.emitOpError("requires ") << shapedType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
// Check that 'memrefVectorElementType' and 'paddingType' types match.
if (memrefVectorElementType != paddingType)
if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
// Source has vector element type.
// Check that 'sourceVectorElementType' and 'paddingType' types match.
if (sourceVectorElementType != paddingType)
return op.emitOpError(
"requires memref element type and padding type to match.");
"requires source element type and padding type to match.");
} else {
// Check that 'paddingType' is valid to store in a vector type.
@ -2067,9 +2069,9 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires valid padding vector elemental type");
// Check that padding type and vector element types match.
if (paddingType != memrefElementType)
if (paddingType != sourceElementType)
return op.emitOpError(
"requires formal padding and memref of the same elemental type");
"requires formal padding and source of the same elemental type");
}
return verifyPermutationMap(permutationMap,
@ -2096,18 +2098,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
// `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
if (op.getMemRefType().isDynamicDim(indicesIdx))
// `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
if (op.getShapedType().isDynamicDim(indicesIdx))
return false;
Value index = op.indices()[indicesIdx];
auto cstOp = index.getDefiningOp<ConstantIndexOp>();
if (!cstOp)
return false;
int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
return cstOp.getValue() + vectorSize <= memrefSize;
return cstOp.getValue() + vectorSize <= sourceSize;
}
template <typename TransferOp>
@ -2159,33 +2161,51 @@ Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices,
Value vector, Value source, ValueRange indices,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
source.getType().cast<MemRefType>(), vectorType);
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permMap,
return build(builder, result, vector, source, indices, permMap,
ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
build(builder, result, vector, source, indices, permMap, maskedArrayAttr);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices,
Value vector, Value source, ValueRange indices,
AffineMap permutationMap) {
build(builder, result, vector, memref, indices, permutationMap,
build(builder, result, vector, source, indices, permutationMap,
/*maybeMasked=*/ArrayAttr());
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMapAttr permutationMap,
/*optional*/ ArrayAttr masked) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
masked);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMap permutationMap,
/*optional*/ ArrayAttr masked) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
masked);
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc typesLoc;
OpAsmParser::OperandType vectorInfo, memrefInfo;
OpAsmParser::OperandType vectorInfo, sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<Type, 2> types;
if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
@ -2196,38 +2216,40 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
VectorType vectorType = types[0].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
MemRefType memRefType = types[1].dyn_cast<MemRefType>();
if (!memRefType)
return parser.emitError(typesLoc, "requires memref type");
ShapedType shapedType = types[1].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands));
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
(shapedType.isa<RankedTensorType>() &&
parser.addTypeToList(shapedType, result.types)));
}
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
<< op.indices() << "]";
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
p << " : " << op.getVectorType() << ", " << op.getShapedType();
}
static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto permutationMap = op.permutation_map();
if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (llvm::size(op.indices()) != shapedType.getRank())
return op.emitOpError("requires ") << shapedType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();


@ -94,7 +94,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> reads;
Operation *firstOverwriteCandidate = nullptr;
for (auto *user : write.memref().getUsers()) {
for (auto *user : write.source().getUsers()) {
if (user == write.getOperation())
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
for (Operation *user : read.memref().getUsers()) {
for (Operation *user : read.source().getUsers()) {
if (isa<vector::TransferReadOp>(user))
continue;
if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {


@ -597,7 +597,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
Location loc = readOp.getLoc();
auto memrefElementType =
readOp.memref().getType().cast<MemRefType>().getElementType();
readOp.source().getType().cast<MemRefType>().getElementType();
auto tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
int64_t numSlices = tupleType.size();
@ -612,7 +612,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
loc, sliceVectorType, readOp.memref(), sliceIndices,
loc, sliceVectorType, readOp.source(), sliceIndices,
readOp.permutation_map(), readOp.padding(),
readOp.masked() ? *readOp.masked() : ArrayAttr());
};
@ -644,14 +644,14 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
Value tuple = builder.create<vector::ExtractSlicesOp>(
loc, tupleType, writeOp.vector(), targetShape, strides);
auto memrefElementType =
writeOp.memref().getType().cast<MemRefType>().getElementType();
writeOp.source().getType().cast<MemRefType>().getElementType();
SmallVector<Value, 4> indices(writeOp.indices().begin(),
writeOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
auto element = builder.create<vector::TupleGetOp>(
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
builder.create<vector::TransferWriteOp>(
loc, element.getResult(), writeOp.memref(), sliceIndices,
loc, element.getResult(), writeOp.source(), sliceIndices,
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
};
@ -760,7 +760,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
Location loc = xferWriteOp.getLoc();
auto memrefElementType =
xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
xferWriteOp.source().getType().cast<MemRefType>().getElementType();
SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
xferWriteOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@ -768,7 +768,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
};
@ -2142,7 +2142,7 @@ static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
// Fold or create the check that `index + vector_size` <= `memref_size`.
Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
Value cond =
createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
if (!cond)
return;
// Conjunction over all dims for which we are in-bounds.
@ -2207,23 +2207,23 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
}
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
Value alloc) {
using namespace edsc::intrinsics;
int64_t memrefRank = xferOp.getMemRefType().getRank();
int64_t memrefRank = xferOp.getShapedType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
"Expected memref rank to match the alloc rank");
Value one = std_constant_index(1);
ValueRange leadingIndices =
xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
xferOp.indices().take_front(xferOp.getLeadingShapedRank());
SmallVector<Value, 4> sizes;
sizes.append(leadingIndices.begin(), leadingIndices.end());
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
Value dimAlloc = std_dim(alloc, resultIdx);
Value index = xferOp.indices()[indicesIdx];
AffineExpr i, j, k;
@ -2235,7 +2235,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
ValueRange{dimMemRef, index, dimAlloc});
sizes.push_back(affineMin);
});
return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
return std_sub_view(xferOp.source(), xferOp.indices(), sizes,
SmallVector<Value, 4>(memrefRank, one));
}
@ -2263,12 +2263,12 @@ static scf::IfOp createScopedFullPartialLinalgCopy(
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
Value memref = xferOp.memref();
Value memref = xferOp.source();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
if (compatibleMemRefType != xferOp.getShapedType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@ -2317,12 +2317,12 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
Value memref = xferOp.memref();
Value memref = xferOp.source();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
if (compatibleMemRefType != xferOp.getShapedType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@ -2376,7 +2376,7 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
@ -2404,8 +2404,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
return failure();
OpBuilder::InsertionGuard guard(b);
if (xferOp.memref().getDefiningOp())
b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
if (Operation *sourceOp = xferOp.source().getDefiningOp())
b.setInsertionPointAfter(sourceOp);
else
b.setInsertionPoint(xferOp);
ScopedContext scope(b, xferOp.getLoc());
@ -2426,8 +2426,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
b.getI64IntegerAttr(32));
}
MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
MemRefType compatibleMemRefType =
getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
alloc.getType().cast<MemRefType>());
// Read case: full fill + partial copy -> unmasked vector.xfer_read.
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
@ -2543,7 +2544,7 @@ struct TransferReadExtractPattern
extract.ids()[idCount++] *
std_constant_index(extract.getResultType().getDimSize(pos));
}
Value newRead = vector_transfer_read(extract.getType(), read.memref(),
Value newRead = vector_transfer_read(extract.getType(), read.source(),
indices, read.permutation_map(),
read.padding(), read.maskedAttr());
Value dest = rewriter.create<ConstantOp>(
@ -2579,7 +2580,7 @@ struct TransferWriteInsertPattern
insert.ids()[idCount++] *
std_constant_index(insert.getSourceVectorType().getDimSize(pos));
}
vector_transfer_write(insert.vector(), write.memref(), indices,
vector_transfer_write(insert.vector(), write.source(), indices,
write.permutation_map(), write.maskedAttr());
rewriter.eraseOp(write);
return success();


@ -243,16 +243,16 @@ AffineMap mlir::makePermutationMap(
return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}
AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
shapedType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
shapedType.getContext());
}
bool matcher::operatesOnSuperVectorsOf(Operation &op,
@ -314,12 +314,12 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB) {
if (transferA.memref() != transferB.memref())
if (transferA.source() != transferB.source())
return false;
// For simplicity only look at transfer of same type.
if (transferA.getVectorType() != transferB.getVectorType())
return false;
unsigned rankOffset = transferA.getLeadingMemRefRank();
unsigned rankOffset = transferA.getLeadingShapedRank();
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();


@ -269,7 +269,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref type}}
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@ -297,7 +297,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0)->(d0)>} : memref<?x?xf32>, vector<128xf32>
}
@ -343,7 +343,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires memref vector element and vector result ranks to match}}
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@ -353,7 +353,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@ -392,7 +392,7 @@ func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref type}}
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@ -419,7 +419,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant dense<3.0> : vector<128 x f32>
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0)->(d0)>} : vector<128xf32>, memref<?x?xf32>
}


@ -43,6 +43,54 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
return
}
// CHECK-LABEL: func @vector_transfer_ops_tensor(
func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%arg1 : tensor<?x?xvector<4x3xf32>>,
%arg2 : tensor<?x?xvector<4x3xi32>>) ->
(tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>){
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
%c0 = constant 0 : i32
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : tensor<?x?xf32>, vector<3x7xf32>
// CHECK: vector.transfer_read
%2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
// CHECK: vector.transfer_write
%7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write
%8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%10 = vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
%11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
return %7, %8, %9, %10, %11 :
tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>
}
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>