[mlir][vector] Extend transfer read/write ops to support tensor types.

Transfer ops can now work on both buffers and tensors. Lowering of the tensor case is not supported yet.

Differential Revision: https://reviews.llvm.org/D93500
parent 9a93f95fce
commit 26c8f9081b
@@ -126,7 +126,7 @@ namespace impl {
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
 /// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
                                       VectorType vectorType);
 } // namespace impl
 } // end namespace vector
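Note (not part of the patch): a minimal usage sketch of the new overload. Because the helper now takes a `ShapedType`, one call path serves both memref- and tensor-backed transfers. The header path is an assumption.

```c++
#include "mlir/Dialect/Vector/VectorUtils.h" // assumed location of the helper
using namespace mlir;

// Sketch: `sourceType` may be a MemRefType or a RankedTensorType; the minor
// identity map is computed the same way for both.
static AffineMap defaultTransferMap(ShapedType sourceType, VectorType vecType) {
  return getTransferMinorIdentityMap(sourceType, vecType);
}
```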
@@ -1056,7 +1056,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
-    Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
+    Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
                OptionalAttr<BoolArrayAttr>:$masked)>,
     Results<(outs AnyVector:$vector)> {
@@ -1065,15 +1065,16 @@ def Vector_TransferReadOp :
   let description = [{
     The `vector.transfer_read` op performs a read from a slice within a
-    [MemRef](../LangRef.md#memref-type) supplied as its first operand
-    into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+    [MemRef](../LangRef.md#memref-type) or a Ranked
+    [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
+    [vector](../LangRef.md#vector-type) of the same base elemental type.

-    A memref operand with vector element type, must have its vector element
-    type match a suffix (shape and element type) of the vector (e.g.
+    A memref/tensor operand with vector element type, must have its vector
+    element type match a suffix (shape and element type) of the vector (e.g.
     memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).

-    The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `2 .. 1 + rank(memref)`.
+    The slice is further defined by a full-rank index within the MemRef/Tensor,
+    supplied as the operands `2 .. 1 + rank(memref/tensor)`.

     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1084,8 +1085,9 @@ def Vector_TransferReadOp :
     The size of the slice is specified by the size of the vector, given as the
     return type.

-    An `ssa-value` of the same elemental type as the MemRef is provided as the
-    last operand to specify padding in the case of out-of-bounds accesses.
+    An `ssa-value` of the same elemental type as the MemRef/Tensor is provided
+    as the last operand to specify padding in the case of out-of-bounds
+    accesses.

     An optional boolean array attribute is provided to specify which dimensions
     of the transfer need masking. When a dimension is specified as not requiring
@@ -1196,17 +1198,22 @@ def Vector_TransferReadOp :
     %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
       {permutation_map = (d0, d1)->(d0, d1)}
       : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+    // Read from a tensor with vector element type.
+    %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
+      {permutation_map = (d0, d1)->(d0, d1)}
+      : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
     ```
   }];

   let builders = [
     // Builder that sets padding to zero.
-    OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "AffineMap":$permutationMap,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
     // Builder that sets permutation map (resp. padding) to
     // 'getMinorIdentityMap' (resp. zero).
-    OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>
   ];
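Note (not part of the patch): a C++ usage sketch built on the declared builders, showing a read from either kind of source. The helper name, include paths and parameter choices are illustrative assumptions; the zero-padding builder above is assumed to fill in the padding value.

```c++
#include "mlir/Dialect/Vector/VectorOps.h"   // assumed header locations
#include "mlir/Dialect/Vector/VectorUtils.h"
using namespace mlir;

// Sketch: build a vector.transfer_read from a memref or ranked tensor source.
static Value buildTransferRead(OpBuilder &b, Location loc, VectorType vecType,
                               Value source, ValueRange indices) {
  // Default minor-identity map computed from the shaped (memref/tensor) type.
  AffineMap map = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vecType);
  return b.create<vector::TransferReadOp>(loc, vecType, source, indices, map);
}
```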
@@ -1217,26 +1224,29 @@ def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
-    ]>,
-    Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
+    ]>,
+    Arguments<(ins AnyVector:$vector, AnyShaped:$source,
                Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map,
-               OptionalAttr<BoolArrayAttr>:$masked)> {
+               OptionalAttr<BoolArrayAttr>:$masked)>,
+    Results<(outs Optional<AnyRankedTensor>:$result)> {

   let summary = "The vector.transfer_write op writes a supervector to memory.";

   let description = [{
     The `vector.transfer_write` op performs a write from a
     [vector](../LangRef.md#vector-type), supplied as its first operand, into a
-    slice within a [MemRef](../LangRef.md#memref-type) of the same base
-    elemental type, supplied as its second operand.
+    slice within a [MemRef](../LangRef.md#memref-type) or a Ranked
+    [Tensor](../LangRef.md#tensor-type) of the same base elemental type,
+    supplied as its second operand.

-    A vector memref operand must have its vector element type match a suffix
-    (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
-    vector<1x1x4x3xf32>).
+    A vector memref/tensor operand must have its vector element type match a
+    suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
+    vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a
+    new tensor of the same type.

-    The slice is further defined by a full-rank index within the MemRef,
-    supplied as the operands `3 .. 2 + rank(memref)`.
+    The slice is further defined by a full-rank index within the MemRef/Tensor,
+    supplied as the operands `3 .. 2 + rank(memref/tensor)`.

     The permutation_map [attribute](../LangRef.md#attributes) is an
     [affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1280,15 +1290,24 @@ def Vector_TransferWriteOp :
     vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
       : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+
+    // return a tensor where the vector is inserted into the source tensor.
+    %5 = vector.transfer_write %4, %arg1[%c3, %c3]
+      {permutation_map = (d0, d1)->(d0, d1)}
+      : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
     ```
   }];

   let builders = [
     // Builder that sets permutation map to 'getMinorIdentityMap'.
-    OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
-    OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMap":$permutationMap)>,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMapAttr":$permutationMap, "ArrayAttr":$masked)>,
+    OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMap":$permutationMap, "ArrayAttr":$masked)>,
   ];

   let hasFolder = 1;
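Note (not part of the patch): a C++ sketch of the write side, using the builder that takes an explicit permutation map. When the destination is a ranked tensor the op produces the updated tensor; with a memref it produces no results. Names and structure are illustrative only.

```c++
#include "mlir/Dialect/Vector/VectorOps.h" // assumed header location
using namespace mlir;

// Sketch: write `vec` into a memref or ranked tensor destination.
static Value buildTransferWrite(OpBuilder &b, Location loc, Value vec,
                                Value dest, ValueRange indices,
                                AffineMap permutationMap) {
  auto write = b.create<vector::TransferWriteOp>(loc, vec, dest, indices,
                                                 permutationMap);
  Operation *op = write.getOperation();
  // Tensor destination: return the new tensor; memref destination: no result.
  return op->getNumResults() ? op->getResult(0) : Value();
}
```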
@@ -20,9 +20,9 @@ class AffineApplyOp;
 class AffineForOp;
 class AffineMap;
 class Location;
-class MemRefType;
 class OpBuilder;
 class Operation;
+class ShapedType;
 class Value;
 class VectorType;
 class VectorTransferOpInterface;
@@ -157,7 +157,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
 /// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
                                       VectorType vectorType);

 /// Return true if we can prove that the transfer operations access disjoint
@@ -47,7 +47,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {

 def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
   let description = [{
-    Encodes properties of an operation on vectors that can be unrolled.
+    Encodes properties of a transfer read or write operation.
   }];
   let cppNamespace = "::mlir";

@@ -83,11 +83,11 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
      }]
    >,
    InterfaceMethod<
-      /*desc=*/"Return the memref operand.",
+      /*desc=*/"Return the memref or ranked tensor operand.",
      /*retTy=*/"Value",
-      /*methodName=*/"memref",
+      /*methodName=*/"source",
      /*args=*/(ins),
-      /*methodBody=*/"return $_op.memref();"
+      /*methodBody=*/"return $_op.source();"
      /*defaultImplementation=*/
    >,
    InterfaceMethod<
@@ -123,13 +123,13 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
      /*defaultImplementation=*/
    >,
    InterfaceMethod<
-      /*desc=*/"Return the MemRefType.",
-      /*retTy=*/"MemRefType",
-      /*methodName=*/"getMemRefType",
+      /*desc=*/"Return the ShapedType.",
+      /*retTy=*/"ShapedType",
+      /*methodName=*/"getShapedType",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/
-        "return $_op.memref().getType().template cast<MemRefType>();"
+        "return $_op.source().getType().template cast<ShapedType>();"
    >,
    InterfaceMethod<
      /*desc=*/"Return the VectorType.",
@@ -152,14 +152,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
        "return $_op.permutation_map().getNumResults();"
    >,
    InterfaceMethod<
-      /*desc=*/[{ Return the number of leading memref dimensions that do not
+      /*desc=*/[{ Return the number of leading shaped dimensions that do not
                  participate in the permutation map.}],
      /*retTy=*/"unsigned",
-      /*methodName=*/"getLeadingMemRefRank",
+      /*methodName=*/"getLeadingShapedRank",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/
-        "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
+        "return $_op.getShapedType().getRank() - $_op.getTransferRank();"
    >,
    InterfaceMethod<
      /*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
@@ -178,8 +178,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
      /*desc=*/[{
        Helper function to account for the fact that `permutationMap` results and
        `op.indices` sizes may not match and may not be aligned. The first
-        `getLeadingMemRefRank()` indices may just be indexed and not transferred
-        from/into the vector.
+        `getLeadingShapedRank()` indices may just be indexed and not
+        transferred from/into the vector.
        For example:
        ```
          vector.transfer %0[%i, %j, %k, %c0] :
@@ -195,7 +195,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        for (int64_t resultIdx = 0,
-                   indicesIdx = $_op.getLeadingMemRefRank(),
+                   indicesIdx = $_op.getLeadingShapedRank(),
                    eResult = $_op.getTransferRank();
             resultIdx < eResult;
             ++resultIdx, ++indicesIdx)
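Note (not part of the patch): a short sketch of code written against the renamed interface methods. `source()` replaces `memref()` and may now be a ranked tensor; the shaped accessors replace the memref-specific ones. Function name and include path are assumptions.

```c++
#include "mlir/Dialect/Vector/VectorOps.h" // pulls in the transfer interface
using namespace mlir;

// Sketch: query a transfer op generically, regardless of read/write and of
// whether the source is a memref or a ranked tensor.
static bool transfersFromTensor(VectorTransferOpInterface xferOp) {
  Value src = xferOp.source();                 // was xferOp.memref()
  ShapedType srcType = xferOp.getShapedType(); // was xferOp.getMemRefType()
  unsigned leadingDims = xferOp.getLeadingShapedRank();
  (void)src;
  (void)leadingDims;
  return srcType.isa<RankedTensorType>();
}
```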
@@ -22,6 +22,17 @@

 using namespace mlir;

+/// Helpers to access the memref operand for each op.
+static Value getMemRefOperand(LoadOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
+
+static Value getMemRefOperand(StoreOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferWriteOp op) {
+  return op.source();
+}
+
 namespace {
 /// Merges subview operation with load/transferRead operation.
 template <typename OpTy>
@@ -141,7 +152,7 @@ template <typename OpTy>
 LogicalResult
 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                              PatternRewriter &rewriter) const {
-  auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
+  auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -162,7 +173,8 @@ template <typename OpTy>
 LogicalResult
 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                               PatternRewriter &rewriter) const {
-  auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
+  auto subViewOp =
+      getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
 }

-// Helper that returns data layout alignment of an operation with memref.
-template <typename T>
-LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
-                                 unsigned &align) {
-  Type elementTy =
-      typeConverter.convertType(op.getMemRefType().getElementType());
+// Helper that returns data layout alignment of a memref.
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
+                                 MemRefType memrefType, unsigned &align) {
+  Type elementTy = typeConverter.convertType(memrefType.getElementType());
   if (!elementTy)
     return failure();

@@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
     return failure();

   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();

   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                             TransferWriteOp xferOp, ArrayRef<Value> operands,
                             Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(
+          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();

   auto adaptor = TransferWriteOpAdaptor(operands);
@@ -345,7 +347,8 @@ public:

     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
+                                  align)))
       return failure();

     auto vtype = typeConverter->convertType(load.getResultVectorType());
@@ -375,7 +378,8 @@ public:

     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
+                                  align)))
       return failure();

     auto vtype = typeConverter->convertType(store.getValueVectorType());
@@ -405,7 +409,8 @@ public:

     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
+                                  align)))
       return failure();

     // Get index ptrs.
@@ -438,7 +443,8 @@ public:

     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
+                                  align)))
       return failure();

     // Get index ptrs.
@@ -1182,8 +1188,11 @@ public:
             xferOp.getVectorType().getRank(),
             xferOp->getContext()))
       return failure();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
     // Only contiguous source tensors supported atm.
-    auto strides = computeContiguousStrides(xferOp.getMemRefType());
+    auto strides = computeContiguousStrides(memRefType);
     if (!strides)
       return failure();

@@ -1192,10 +1201,9 @@ public:
     };

     Location loc = xferOp->getLoc();
-    MemRefType memRefType = xferOp.getMemRefType();

     if (auto memrefVectorElementType =
-            memRefType.getElementType().dyn_cast<VectorType>()) {
+            memRefType.getElementType().template dyn_cast<VectorType>()) {
       // Memref has vector element type.
       if (memrefVectorElementType.getElementType() !=
           xferOp.getVectorType().getElementType())
@@ -1222,7 +1230,7 @@ public:
     // address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;
@@ -1248,7 +1256,7 @@ public:
     unsigned vecWidth = vecTy.getVectorNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
-    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
+    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
     Value mask = buildVectorComparison(
         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
@@ -89,7 +89,9 @@ public:
       return failure();

     // Obtain dataPtr and elementType from the memref.
-    MemRefType memRefType = xferOp.getMemRefType();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
     // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
     // In case of 3(LDS): fall back to vector->llvm pass
     // In case of 5(VGPR): wrong
@@ -101,7 +103,7 @@ public:
     // indices, so no need to calculate offset size in bytes again in
     // the MUBUF instruction.
     Value dataPtr = this->getStridedElementPtr(
-        loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);

     // 1. Create and fill a <4 x i32> dwordConfig with:
     // 1st two elements holding the address of dataPtr.
@@ -107,7 +107,7 @@ public:
     // TODO: when we go to k > 1-D vectors adapt minorRank.
     minorRank = 1;
     majorRank = vectorType.getRank() - minorRank;
-    leadingRank = xferOp.getLeadingMemRefRank();
+    leadingRank = xferOp.getLeadingShapedRank();
     majorVectorType =
         VectorType::get(vectorType.getShape().take_front(majorRank),
                         vectorType.getElementType());
@@ -115,9 +115,9 @@ public:
         VectorType::get(vectorType.getShape().take_back(minorRank),
                         vectorType.getElementType());
     /// Memref of minor vector type is used for individual transfers.
-    memRefMinorVectorType =
-        MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
-                        xferOp.getMemRefType().getMemorySpace());
+    memRefMinorVectorType = MemRefType::get(
+        majorVectorType.getShape(), minorVectorType, {},
+        xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
   }

   LogicalResult doReplace();
@@ -155,7 +155,7 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(
                          const MemRefBoundsCapture &)>
         loopBodyBuilder) {
   /// Loop nest operates on the major dimensions
-  MemRefBoundsCapture memrefBoundsCapture(xferOp.memref());
+  MemRefBoundsCapture memrefBoundsCapture(xferOp.source());

   if (options.unroll) {
     auto shape = majorVectorType.getShape();
@@ -272,9 +272,9 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
     indexing.append(leadingOffsets.begin(), leadingOffsets.end());
     indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
     indexing.append(minorOffsets.begin(), minorOffsets.end());
-    Value memref = xferOp.memref();
+    Value memref = xferOp.source();
     auto map =
-        getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+        getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
     ArrayAttr masked;
     if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
       OpBuilder &b = ScopedContext::getBuilderRef();
@@ -379,13 +379,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
     else
       result = std_load(alloc, majorIvs);
     auto map =
-        getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+        getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
     ArrayAttr masked;
     if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
       OpBuilder &b = ScopedContext::getBuilderRef();
       masked = b.getBoolArrayAttr({false});
     }
-    vector_transfer_write(result, xferOp.memref(), indexing,
+    vector_transfer_write(result, xferOp.source(), indexing,
                           AffineMapAttr::get(map), masked);
   };

@@ -422,7 +422,7 @@ template <typename TransferOpTy>
 static int computeCoalescedIndex(TransferOpTy transfer) {
   // rank of the remote memory access, coalescing behavior occurs on the
   // innermost memory dimension.
-  auto remoteRank = transfer.getMemRefType().getRank();
+  auto remoteRank = transfer.getShapedType().getRank();
   // Iterate over the results expressions of the permutation map to determine
   // the loop order for creating pointwise copies between remote and local
   // memories.
@@ -536,13 +536,14 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;

   TransferReadOp transfer = cast<TransferReadOp>(op);

+  auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
+  if (!memRefType)
+    return failure();
   // Fall back to a loop if the fastest varying stride is not 1 or it is
   // permuted.
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  auto successStrides =
-      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (succeeded(successStrides) && strides.back() == 1 &&
       transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -557,8 +558,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   // Conservative lowering to scalar load / stores.
   // 1. Setup all the captures.
   ScopedContext scope(rewriter, transfer.getLoc());
-  StdIndexedValue remote(transfer.memref());
-  MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+  StdIndexedValue remote(transfer.source());
+  MemRefBoundsCapture memRefBoundsCapture(transfer.source());
   VectorBoundsCapture vectorBoundsCapture(transfer.vector());
   int coalescedIdx = computeCoalescedIndex(transfer);
   // Swap the vectorBoundsCapture which will reorder loop bounds.
@@ -621,13 +622,15 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
   using namespace edsc::op;

   TransferWriteOp transfer = cast<TransferWriteOp>(op);
+  auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
+  if (!memRefType)
+    return failure();

   // Fall back to a loop if the fastest varying stride is not 1 or it is
   // permuted.
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  auto successStrides =
-      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (succeeded(successStrides) && strides.back() == 1 &&
       transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -641,8 +644,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(

   // 1. Setup all the captures.
   ScopedContext scope(rewriter, transfer.getLoc());
-  StdIndexedValue remote(transfer.memref());
-  MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+  StdIndexedValue remote(transfer.source());
+  MemRefBoundsCapture memRefBoundsCapture(transfer.source());
   Value vectorValue(transfer.vector());
   VectorBoundsCapture vectorBoundsCapture(transfer.vector());
   int coalescedIdx = computeCoalescedIndex(transfer);
@@ -111,7 +111,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
     vector::TransferWriteOp transferWrite;
     for (auto *sliceOp : llvm::reverse(forwardSlice)) {
       auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
-      if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
+      if (!candidateWrite || candidateWrite.source() != transferRead.source())
         continue;
       transferWrite = candidateWrite;
     }
@@ -142,7 +142,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
     DominanceInfo dom(loop);
     if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
       return WalkResult::advance();
-    for (auto &use : transferRead.memref().getUses()) {
+    for (auto &use : transferRead.source().getUses()) {
       if (!dom.properlyDominates(loop, use.getOwner()))
         continue;
       if (use.getOwner() == transferRead.getOperation() ||
@@ -411,7 +411,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

   // Transfer into `view`.
-  Value viewOrAlloc = xferOp.memref();
+  Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
       !viewOrAlloc.getDefiningOp<AllocOp>())
     return failure();
@@ -487,7 +487,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
   // Transfer into `viewOrAlloc`.
-  Value viewOrAlloc = xferOp.memref();
+  Value viewOrAlloc = xferOp.source();
   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
       !viewOrAlloc.getDefiningOp<AllocOp>())
     return failure();
@@ -1890,41 +1890,43 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
   return success();
 }

-static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                       VectorType vectorType,
                                       AffineMap permutationMap,
                                       ArrayAttr optionalMasked) {
-  auto memrefElementType = memrefType.getElementType();
-  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
-    // Memref has vector element type.
-
-    unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
-                             memrefVectorElementType.getShape().back();
+  if (!shapedType.isa<MemRefType, RankedTensorType>())
+    return op->emitOpError(
+        "requires source to be a memref or ranked tensor type");
+  auto elementType = shapedType.getElementType();
+  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
+    // Memref or tensor has vector element type.
+    unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
+                             vectorElementType.getShape().back();
     unsigned resultVecSize =
         vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
-    if (resultVecSize % memrefVecSize != 0)
+    if (resultVecSize % sourceVecSize != 0)
       return op->emitOpError(
           "requires the bitwidth of the minor 1-D vector to be an integral "
-          "multiple of the bitwidth of the minor 1-D vector of the memref");
+          "multiple of the bitwidth of the minor 1-D vector of the source");

-    unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+    unsigned sourceVecEltRank = vectorElementType.getRank();
     unsigned resultVecRank = vectorType.getRank();
-    if (memrefVecEltRank > resultVecRank)
+    if (sourceVecEltRank > resultVecRank)
       return op->emitOpError(
-          "requires memref vector element and vector result ranks to match.");
-    unsigned rankOffset = resultVecRank - memrefVecEltRank;
+          "requires source vector element and vector result ranks to match.");
+    unsigned rankOffset = resultVecRank - sourceVecEltRank;
     // Check that permutation map results match 'rankOffset' of vector type.
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
   } else {
-    // Memref has scalar element type.
+    // Memref or tensor has scalar element type.
     unsigned resultVecSize =
         vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
-    if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
+    if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
       return op->emitOpError(
           "requires the bitwidth of the minor 1-D vector to be an integral "
-          "multiple of the bitwidth of the memref element type");
+          "multiple of the bitwidth of the source element type");

     // Check that permutation map results match rank of vector type.
     if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1934,9 +1936,9 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,

   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
-  if (permutationMap.getNumInputs() != memrefType.getRank())
+  if (permutationMap.getNumInputs() != shapedType.getRank())
     return op->emitOpError("requires a permutation_map with input dims of the "
-                           "same rank as the memref type");
+                           "same rank as the source type");

   if (optionalMasked) {
     if (permutationMap.getNumResults() !=
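Note (not part of the patch): the minor 1-D bitwidth rule above, with the concrete numbers used by the invalid.mlir test near the end of this diff.

```c++
// Source element type vector<6xf32> versus result minor vector vector<3xf32>:
constexpr unsigned sourceVecBits = 6 * 32; // 192
constexpr unsigned resultVecBits = 3 * 32; // 96
static_assert(resultVecBits % sourceVecBits != 0,
              "rejected by verifyTransferOp: result bitwidth is not an "
              "integral multiple of the source's minor 1-D vector bitwidth");
```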
@@ -1978,7 +1980,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 2> elidedAttrs;
   if (op.permutation_map() ==
-      getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
+      getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {
@@ -1995,21 +1997,21 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
 }

 static void print(OpAsmPrinter &p, TransferReadOp op) {
-  p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
+  p << op.getOperationName() << " " << op.source() << "[" << op.indices()
     << "], " << op.padding();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
-  p << " : " << op.getMemRefType() << ", " << op.getVectorType();
+  p << " : " << op.getShapedType() << ", " << op.getVectorType();
 }

 static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                        OperationState &result) {
   llvm::SMLoc typesLoc;
-  OpAsmParser::OperandType memrefInfo;
+  OpAsmParser::OperandType sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   OpAsmParser::OperandType paddingInfo;
   SmallVector<Type, 2> types;
   // Parsing with support for paddingValue.
-  if (parser.parseOperand(memrefInfo) ||
+  if (parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseComma() || parser.parseOperand(paddingInfo) ||
       parser.parseOptionalAttrDict(result.attributes) ||
@@ -2018,48 +2020,48 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
   auto indexType = parser.getBuilder().getIndexType();
-  MemRefType memRefType = types[0].dyn_cast<MemRefType>();
-  if (!memRefType)
-    return parser.emitError(typesLoc, "requires memref type");
+  auto shapedType = types[0].dyn_cast<ShapedType>();
+  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
   VectorType vectorType = types[1].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
-      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
+      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
-      parser.resolveOperand(paddingInfo, memRefType.getElementType(),
+      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                             result.operands) ||
       parser.addTypeToList(vectorType, result.types));
 }

 static LogicalResult verify(TransferReadOp op) {
-  // Consistency of elemental types in memref and vector.
-  MemRefType memrefType = op.getMemRefType();
+  // Consistency of elemental types in source and vector.
+  ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto paddingType = op.padding().getType();
   auto permutationMap = op.permutation_map();
-  auto memrefElementType = memrefType.getElementType();
+  auto sourceElementType = shapedType.getElementType();

-  if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
-    return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+  if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
+    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

-  if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.masked() ? *op.masked() : ArrayAttr())))
     return failure();

-  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
-    // Memref has vector element type.
-    // Check that 'memrefVectorElementType' and 'paddingType' types match.
-    if (memrefVectorElementType != paddingType)
+  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
+    // Source has vector element type.
+    // Check that 'sourceVectorElementType' and 'paddingType' types match.
+    if (sourceVectorElementType != paddingType)
       return op.emitOpError(
-          "requires memref element type and padding type to match.");
+          "requires source element type and padding type to match.");

   } else {
     // Check that 'paddingType' is valid to store in a vector type.
@@ -2067,9 +2069,9 @@ static LogicalResult verify(TransferReadOp op) {
       return op.emitOpError("requires valid padding vector elemental type");

     // Check that padding type and vector element types match.
-    if (paddingType != memrefElementType)
+    if (paddingType != sourceElementType)
       return op.emitOpError(
-          "requires formal padding and memref of the same elemental type");
+          "requires formal padding and source of the same elemental type");
   }

   return verifyPermutationMap(permutationMap,
@@ -2096,18 +2098,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
-  // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
-  if (op.getMemRefType().isDynamicDim(indicesIdx))
+  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
+  if (op.getShapedType().isDynamicDim(indicesIdx))
     return false;
   Value index = op.indices()[indicesIdx];
   auto cstOp = index.getDefiningOp<ConstantIndexOp>();
   if (!cstOp)
     return false;

-  int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
+  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

-  return cstOp.getValue() + vectorSize <= memrefSize;
+  return cstOp.getValue() + vectorSize <= sourceSize;
 }

 template <typename TransferOp>
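Note (not part of the patch): the `isInBounds` fold above compares a constant start index plus the vector size against the now-shaped source dimension. With made-up numbers:

```c++
// Static source dim of size 8, constant start index 3, transferred vector<4xf32>:
constexpr int64_t sourceSize = 8, cstIndex = 3, vectorSize = 4;
static_assert(cstIndex + vectorSize <= sourceSize,
              "dimension is in-bounds, so masking can be dropped for it");
```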
@@ -2159,33 +2161,51 @@ Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {

 /// Builder that sets permutation map to 'getMinorIdentityMap'.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices,
+                            Value vector, Value source, ValueRange indices,
                             ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
-      memref.getType().cast<MemRefType>(), vectorType);
+      source.getType().cast<MemRefType>(), vectorType);
   if (maybeMasked.empty())
-    return build(builder, result, vector, memref, indices, permMap,
+    return build(builder, result, vector, source, indices, permMap,
                  ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+  build(builder, result, vector, source, indices, permMap, maskedArrayAttr);
 }

 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value memref, ValueRange indices,
+                            Value vector, Value source, ValueRange indices,
                             AffineMap permutationMap) {
-  build(builder, result, vector, memref, indices, permutationMap,
+  build(builder, result, vector, source, indices, permutationMap,
         /*maybeMasked=*/ArrayAttr());
 }

+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMapAttr permutationMap,
+                            /*optional*/ ArrayAttr masked) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        masked);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMap permutationMap,
+                            /*optional*/ ArrayAttr masked) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        masked);
+}
+
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                         OperationState &result) {
   llvm::SMLoc typesLoc;
-  OpAsmParser::OperandType vectorInfo, memrefInfo;
+  OpAsmParser::OperandType vectorInfo, sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   SmallVector<Type, 2> types;
   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
-      parser.parseOperand(memrefInfo) ||
+      parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
@@ -2196,38 +2216,40 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
   VectorType vectorType = types[0].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
-  MemRefType memRefType = types[1].dyn_cast<MemRefType>();
-  if (!memRefType)
-    return parser.emitError(typesLoc, "requires memref type");
+  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
+  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
   auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
       parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
-      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexType, result.operands));
+      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+      parser.resolveOperands(indexInfo, indexType, result.operands) ||
+      (shapedType.isa<RankedTensorType>() &&
+       parser.addTypeToList(shapedType, result.types)));
 }

 static void print(OpAsmPrinter &p, TransferWriteOp op) {
-  p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
+  p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
     << op.indices() << "]";
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
-  p << " : " << op.getVectorType() << ", " << op.getMemRefType();
+  p << " : " << op.getVectorType() << ", " << op.getShapedType();
 }

 static LogicalResult verify(TransferWriteOp op) {
   // Consistency of elemental types in memref and vector.
-  MemRefType memrefType = op.getMemRefType();
+  ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto permutationMap = op.permutation_map();

-  if (llvm::size(op.indices()) != memrefType.getRank())
-    return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+  if (llvm::size(op.indices()) != shapedType.getRank())
+    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

-  if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.masked() ? *op.masked() : ArrayAttr())))
     return failure();
@@ -94,7 +94,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> reads;
   Operation *firstOverwriteCandidate = nullptr;
-  for (auto *user : write.memref().getUsers()) {
+  for (auto *user : write.source().getUsers()) {
     if (user == write.getOperation())
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  for (Operation *user : read.memref().getUsers()) {
+  for (Operation *user : read.source().getUsers()) {
     if (isa<vector::TransferReadOp>(user))
       continue;
     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -597,7 +597,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,

   Location loc = readOp.getLoc();
   auto memrefElementType =
-      readOp.memref().getType().cast<MemRefType>().getElementType();
+      readOp.source().getType().cast<MemRefType>().getElementType();
   auto tupleType = generateExtractSlicesOpResultType(
       sourceVectorType, targetShape, strides, builder);
   int64_t numSlices = tupleType.size();
@@ -612,7 +612,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
     // `masked` attribute propagates conservatively: if the coarse op didn't
     // need masking, the fine op doesn't either.
     vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
-        loc, sliceVectorType, readOp.memref(), sliceIndices,
+        loc, sliceVectorType, readOp.source(), sliceIndices,
         readOp.permutation_map(), readOp.padding(),
         readOp.masked() ? *readOp.masked() : ArrayAttr());
   };
@@ -644,14 +644,14 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
   Value tuple = builder.create<vector::ExtractSlicesOp>(
       loc, tupleType, writeOp.vector(), targetShape, strides);
   auto memrefElementType =
-      writeOp.memref().getType().cast<MemRefType>().getElementType();
+      writeOp.source().getType().cast<MemRefType>().getElementType();
   SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                 writeOp.indices().end());
   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
     auto element = builder.create<vector::TupleGetOp>(
         loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
     builder.create<vector::TransferWriteOp>(
-        loc, element.getResult(), writeOp.memref(), sliceIndices,
+        loc, element.getResult(), writeOp.source(), sliceIndices,
         writeOp.permutation_map(),
         writeOp.masked() ? *writeOp.masked() : ArrayAttr());
   };
@@ -760,7 +760,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {

     Location loc = xferWriteOp.getLoc();
     auto memrefElementType =
-        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
+        xferWriteOp.source().getType().cast<MemRefType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -768,7 +768,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
       // `masked` attribute propagates conservatively: if the coarse op didn't
       // need masking, the fine op doesn't either.
      rewriter.create<vector::TransferWriteOp>(
-          loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
+          loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
          xferWriteOp.permutation_map(),
          xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
     };
@@ -2142,7 +2142,7 @@ static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
     // Fold or create the check that `index + vector_size` <= `memref_size`.
     Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
     Value cond =
-        createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
+        createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
     if (!cond)
       return;
     // Conjunction over all dims for which we are in-bounds.
@@ -2207,23 +2207,23 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 }

 /// Operates under a scoped context to build the intersection between the
-/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
                                              Value alloc) {
   using namespace edsc::intrinsics;
-  int64_t memrefRank = xferOp.getMemRefType().getRank();
+  int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
   assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
          "Expected memref rank to match the alloc rank");
   Value one = std_constant_index(1);
   ValueRange leadingIndices =
-      xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
   SmallVector<Value, 4> sizes;
   sizes.append(leadingIndices.begin(), leadingIndices.end());
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+    Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
     Value dimAlloc = std_dim(alloc, resultIdx);
     Value index = xferOp.indices()[indicesIdx];
     AffineExpr i, j, k;
@@ -2235,7 +2235,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
                        ValueRange{dimMemRef, index, dimAlloc});
     sizes.push_back(affineMin);
   });
-  return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+  return std_sub_view(xferOp.source(), xferOp.indices(), sizes,
                       SmallVector<Value, 4>(memrefRank, one));
 }

@@ -2263,12 +2263,12 @@ static scf::IfOp createScopedFullPartialLinalgCopy(
   using namespace edsc::intrinsics;
   scf::IfOp fullPartialIfOp;
   Value zero = std_constant_index(0);
-  Value memref = xferOp.memref();
+  Value memref = xferOp.source();
   conditionBuilder(
       returnTypes, inBoundsCond,
       [&]() -> scf::ValueVector {
         Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
+        if (compatibleMemRefType != xferOp.getShapedType())
           res = std_memref_cast(memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2317,12 +2317,12 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
   using namespace edsc::intrinsics;
   scf::IfOp fullPartialIfOp;
   Value zero = std_constant_index(0);
-  Value memref = xferOp.memref();
+  Value memref = xferOp.source();
   conditionBuilder(
       returnTypes, inBoundsCond,
       [&]() -> scf::ValueVector {
         Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
+        if (compatibleMemRefType != xferOp.getShapedType())
           res = std_memref_cast(memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2376,7 +2376,7 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
 ///
 /// Preconditions:
 /// 1. `xferOp.permutation_map()` must be a minor identity map
-/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
 ///    must be equal. This will be relaxed in the future but requires
 ///    rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
@@ -2404,8 +2404,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     return failure();

   OpBuilder::InsertionGuard guard(b);
-  if (xferOp.memref().getDefiningOp())
-    b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
+  if (Operation *sourceOp = xferOp.source().getDefiningOp())
+    b.setInsertionPointAfter(sourceOp);
   else
     b.setInsertionPoint(xferOp);
   ScopedContext scope(b, xferOp.getLoc());
@@ -2426,8 +2426,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
                                    b.getI64IntegerAttr(32));
   }

-  MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
-      xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
+  MemRefType compatibleMemRefType =
+      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
+                                  alloc.getType().cast<MemRefType>());

   // Read case: full fill + partial copy -> unmasked vector.xfer_read.
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
@@ -2543,7 +2544,7 @@ struct TransferReadExtractPattern
                  extract.ids()[idCount++] *
                  std_constant_index(extract.getResultType().getDimSize(pos));
     }
-    Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+    Value newRead = vector_transfer_read(extract.getType(), read.source(),
                                          indices, read.permutation_map(),
                                          read.padding(), read.maskedAttr());
     Value dest = rewriter.create<ConstantOp>(
@@ -2579,7 +2580,7 @@ struct TransferWriteInsertPattern
                  insert.ids()[idCount++] *
                  std_constant_index(insert.getSourceVectorType().getDimSize(pos));
     }
-    vector_transfer_write(insert.vector(), write.memref(), indices,
+    vector_transfer_write(insert.vector(), write.source(), indices,
                           write.permutation_map(), write.maskedAttr());
     rewriter.eraseOp(write);
     return success();
@ -243,16 +243,16 @@ AffineMap mlir::makePermutationMap(
return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}

AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
shapedType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
shapedType.getContext());
}
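Concretely, for a source with a vector element type the element rank is subtracted before building the map. For example (shapes borrowed from the tests below, value names illustrative), a tensor<?x?xvector<4x3xf32>> source read into a vector<1x1x4x3xf32> has elementVectorRank = 2, so the minor identity map covers only the two leading dimensions:

// Default (minor identity) permutation map for a 2-D source whose element
// type is itself a 2-D vector: affine_map<(d0, d1) -> (d0, d1)>.
%r = vector.transfer_read %t[%c3, %c3], %vf0
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
    : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>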
bool matcher::operatesOnSuperVectorsOf(Operation &op,
@ -314,12 +314,12 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB) {
if (transferA.memref() != transferB.memref())
if (transferA.source() != transferB.source())
return false;
// For simplicity only look at transfer of same type.
if (transferA.getVectorType() != transferB.getVectorType())
return false;
unsigned rankOffset = transferA.getLeadingMemRefRank();
unsigned rankOffset = transferA.getLeadingShapedRank();
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
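Informally, two transfers are only reported as disjoint when they touch the same source and their constant indices (offset by the leading source rank) provably select non-overlapping windows; anything non-constant conservatively fails the check. A hypothetical pair that such an analysis could prove disjoint:

// Same source, same vector type; constant indices 0 and 8 along the accessed
// dimension with a 4-element vector, so the two windows cannot overlap.
%a = vector.transfer_read %m[%c0, %c0], %f0 : memref<?x?xf32>, vector<4xf32>
%b = vector.transfer_read %m[%c0, %c8], %f0 : memref<?x?xf32>, vector<4xf32>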
@ -269,7 +269,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref type}}
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@ -297,7 +297,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0)->(d0)>} : memref<?x?xf32>, vector<128xf32>
}
@ -343,7 +343,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires memref vector element and vector result ranks to match}}
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@ -353,7 +353,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@ -392,7 +392,7 @@ func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref type}}
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@ -419,7 +419,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant dense<3.0> : vector<128 x f32>
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
// expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0)->(d0)>} : vector<128xf32>, memref<?x?xf32>
}
@ -43,6 +43,54 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
return
}

// CHECK-LABEL: func @vector_transfer_ops_tensor(
func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%arg1 : tensor<?x?xvector<4x3xf32>>,
%arg2 : tensor<?x?xvector<4x3xi32>>) ->
(tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>){
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
%c0 = constant 0 : i32
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>

//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : tensor<?x?xf32>, vector<3x7xf32>
// CHECK: vector.transfer_read
%2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : tensor<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>

// CHECK: vector.transfer_write
%7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write
%8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%10 = vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
%11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>

return %7, %8, %9, %10, %11 :
tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>
}
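Note that, unlike the memref case, the tensor variants of vector.transfer_write above return a new tensor (%7 through %11) rather than writing in place, so successive writes are chained through SSA values. A minimal hypothetical sketch (value names and shapes are illustrative):

// Each tensor write yields a fresh tensor; the next write consumes it.
%t1 = vector.transfer_write %va, %t0[%c3, %c3] : vector<128xf32>, tensor<?x?xf32>
%t2 = vector.transfer_write %vb, %t1[%c3, %c3] : vector<128xf32>, tensor<?x?xf32>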
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>