[mlir][Vector] Add an optional "masked" boolean array attribute to vector transfer operations
Summary: The semantics of vector transfer ops are extended to allow specifying a per-dimension `masked` boolean array attribute. When the attribute is false for a particular dimension, lowering to LLVM emits unmasked load and store operations for that dimension.

Differential Revision: https://reviews.llvm.org/D80098
parent 36cdc17f8c
commit 1870e787af
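Note: as a usage illustration of the new attribute (the SSA names below are hypothetical; the actual test cases are part of the diff), a 1-D read whose only dimension is declared in-bounds can be lowered to a plain vector load rather than a masked-load intrinsic:

  %f = vector.transfer_read %A[%base], %pad {masked = [false]} :
    memref<?xf32>, vector<17xf32>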
@@ -865,7 +865,12 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferOpUtils {
code extraTransferDeclaration = [{
static StringRef getMaskedAttrName() { return "masked"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
bool isMaskedDim(unsigned dim) {
return !masked() ||
masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
}
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}

@@ -878,14 +883,15 @@ def Vector_TransferOpUtils {
def Vector_TransferReadOp :
Vector_Op<"transfer_read">,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding)>,
AffineMapAttr:$permutation_map, AnyType:$padding,
OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs AnyVector:$vector)> {

let summary = "Reads a supervector from memory into an SSA vector value.";

let description = [{
The `vector.transfer_read` op performs a blocking read from a slice within
a [MemRef](../LangRef.md#memref-type) supplied as its first operand
The `vector.transfer_read` op performs a read from a slice within a
[MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same base elemental type.

A memref operand with vector element type, must have its vector element

@@ -893,8 +899,9 @@ def Vector_TransferReadOp :
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).

The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
[attribute](../LangRef.md#attributes) is an
supplied as the operands `2 .. 1 + rank(memref)`.

The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
omitted from parsing and printing if it is the canonical minor identity map

@@ -906,6 +913,12 @@ def Vector_TransferReadOp :
An `ssa-value` of the same elemental type as the MemRef is provided as the
last operand to specify padding in the case of out-of-bounds accesses.

An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
masking, the `vector.transfer_read` may be lowered to simple loads. The
absence of this `masked` attribute signifies that all dimensions of the
transfer need to be masked.

This operation is called 'read' by opposition to 'load' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_read` is thus a mid-level abstraction
@@ -1015,11 +1028,13 @@ def Vector_TransferReadOp :
let builders = [
// Builder that sets padding to zero.
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
"Value memref, ValueRange indices, AffineMap permutationMap">,
"Value memref, ValueRange indices, AffineMap permutationMap, "
"ArrayRef<bool> maybeMasked = {}">,
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
"Value memref, ValueRange indices">
"Value memref, ValueRange indices, "
"ArrayRef<bool> maybeMasked = {}">
];

let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #

@@ -1039,12 +1054,13 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write">,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map)> {
AffineMapAttr:$permutation_map,
OptionalAttr<BoolArrayAttr>:$masked)> {

let summary = "The vector.transfer_write op writes a supervector to memory.";

let description = [{
The `vector.transfer_write` performs a blocking write from a
The `vector.transfer_write` op performs a write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
slice within a [MemRef](../LangRef.md#memref-type) of the same base
elemental type, supplied as its second operand.

@@ -1055,6 +1071,7 @@ def Vector_TransferWriteOp :
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `3 .. 2 + rank(memref)`.

The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
@@ -1063,6 +1080,12 @@ def Vector_TransferWriteOp :
The size of the slice is specified by the size of the vector.

An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
masking, the `vector.transfer_write` may be lowered to simple stores. The
absence of this `masked` attribute signifies that all dimensions of the
transfer need to be masked.

This operation is called 'write' by opposition to 'store' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_write` is thus a
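Note: masking is specified per dimension. As an illustrative sketch (hypothetical values, in the same form as the `masked = [true, false]` cases added to the tests below), a 2-D write whose minor dimension is known to be in-bounds would be written:

  vector.transfer_write %val, %A[%i, %j] {masked = [true, false]} :
    vector<4x8xf32>, memref<?x?xf32>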
@@ -1097,7 +1120,10 @@ def Vector_TransferWriteOp :
let builders = [
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
"Value memref, ValueRange indices">
"Value memref, ValueRange indices, "
"ArrayRef<bool> maybeMasked = {}">,
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
"Value memref, ValueRange indices, AffineMap permutationMap">,
];

let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #

@@ -746,12 +746,6 @@ public:
}
};

template <typename ConcreteOp>
LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Operation *op, ArrayRef<Value> operands,
Value dataPtr, Value mask);

LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
Type type, LLVM::LLVMType &llvmType,
unsigned &align) {

@@ -765,12 +759,25 @@ LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
return success();
}

template <>
LogicalResult replaceTransferOp<TransferReadOp>(
ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
Value mask) {
auto xferOp = cast<TransferReadOp>(op);
LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
return success();
}

LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
Location loc, TransferReadOp xferOp,
ArrayRef<Value> operands,
Value dataPtr, Value mask) {
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
VectorType fillType = xferOp.getVectorType();
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

@@ -783,19 +790,32 @@ LogicalResult replaceTransferOp<TransferReadOp>(
return failure();

rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
op, vecTy, dataPtr, mask, ValueRange{fill},
xferOp, vecTy, dataPtr, mask, ValueRange{fill},
rewriter.getI32IntegerAttr(align));
return success();
}

template <>
LogicalResult replaceTransferOp<TransferWriteOp>(
ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
Value mask) {
LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
return success();
}

auto xferOp = cast<TransferWriteOp>(op);
LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
Location loc, TransferWriteOp xferOp,
ArrayRef<Value> operands,
Value dataPtr, Value mask) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),

@@ -803,7 +823,8 @@ LogicalResult replaceTransferOp<TransferWriteOp>(
return failure();

rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
return success();
}
@@ -877,6 +898,10 @@ public:
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc, vecTy.getPointerTo(), dataPtr);

if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
xferOp, operands, vectorDataPtr);

// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
unsigned vecWidth = vecTy.getVectorNumElements();
VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);

@@ -910,8 +935,8 @@ public:
mask);

// 5. Rewrite as a masked read / write.
return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op,
operands, vectorDataPtr, mask);
return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
operands, vectorDataPtr, mask);
}
};

@@ -157,25 +157,34 @@ void NDTransferOpHelper<ConcreteOp>::emitInBounds(
ValueRange majorIvs, ValueRange majorOffsets,
MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder,
LambdaElse elseBlockBuilder) {
Value inBounds = std_constant_int(/*value=*/1, /*width=*/1);
Value inBounds;
SmallVector<Value, 4> majorIvsPlusOffsets;
majorIvsPlusOffsets.reserve(majorIvs.size());
unsigned idx = 0;
for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
Value inBounds2 = majorIvsPlusOffsets.back() < ub;
inBounds = inBounds && inBounds2;
if (xferOp.isMaskedDim(leadingRank + idx)) {
Value inBounds2 = majorIvsPlusOffsets.back() < ub;
inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
}
++idx;
}

auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBounds,
/*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
if (std::is_same<ConcreteOp, TransferReadOp>())
BlockBuilder(&ifOp.elseRegion().front(),
Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
if (inBounds) {
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBounds,
/*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
if (std::is_same<ConcreteOp, TransferReadOp>())
BlockBuilder(&ifOp.elseRegion().front(),
Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
} else {
// Just build the body of the then block right here.
thenBlockBuilder(majorIvsPlusOffsets);
}
}

template <>
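Note: a simplified sketch of what emitInBounds now produces for one major loop iteration (names are illustrative; the exact patterns appear in the vector-to-loops test updates below). A masked major dimension keeps its bounds check and scf.if guard, while an unmasked one has the body emitted directly:

  // Masked major dimension: guard the 1-D transfer with a bounds check.
  %cond = cmpi "slt", %add, %dim : index
  scf.if %cond {
    %v = vector.transfer_read %A[%add, %base], %pad : memref<?x?xf32>, vector<15xf32>
    store %v, %alloc[%i] : memref<17xvector<15xf32>>
  }
  // Unmasked major dimension: the same body, with no scf.if around it.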
@@ -192,13 +201,18 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
// Lower to 1-D vector_transfer_read and let recursion handle it.

Value memref = xferOp.memref();
auto map = TransferReadOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
auto loaded1D =
vector_transfer_read(minorVectorType, memref, indexing,
AffineMapAttr::get(map), xferOp.padding());
ArrayAttr masked;
if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
AffineMapAttr::get(map),
xferOp.padding(), masked);
// Store the 1-D vector.
std_store(loaded1D, alloc, majorIvs);
};

@@ -229,7 +243,6 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
ValueRange majorOffsets, ValueRange minorOffsets,
MemRefBoundsCapture &memrefBounds) {
auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
// Lower to 1-D vector_transfer_write and let recursion handle it.
SmallVector<Value, 8> indexing;
indexing.reserve(leadingRank + majorRank + minorRank);
indexing.append(leadingOffsets.begin(), leadingOffsets.end());

@@ -239,8 +252,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
Value loaded1D = std_load(alloc, majorIvs);
auto map = TransferWriteOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
ArrayAttr masked;
if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
vector_transfer_write(loaded1D, xferOp.memref(), indexing,
AffineMapAttr::get(map));
AffineMapAttr::get(map), masked);
};
// Don't write anything when out of bounds.
auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};

@@ -1017,8 +1017,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::TransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices,
AffineMapAttr::get(permutationMap));
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = transfer.getOperation();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminals" (i.e. AffineStoreOps) are erased on the spot.

@@ -1202,6 +1202,23 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap
mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
}

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
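Note: two concrete instances of this map, matching type pairings exercised in the tests below (shown only for illustration):

  // memref<?x?xf32> with vector<128xf32>: getMinorIdentityMap(2, 1)
  affine_map<(d0, d1) -> (d1)>
  // memref<?x?xvector<4x3xf32>> with vector<1x1x4x3xf32>: the element-vector
  // rank (2) is subtracted first, giving getMinorIdentityMap(2, 2)
  affine_map<(d0, d1) -> (d0, d1)>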
@@ -1233,7 +1250,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,

static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
VectorType vectorType,
AffineMap permutationMap) {
AffineMap permutationMap,
ArrayAttr optionalMasked) {
auto memrefElementType = memrefType.getElementType();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
@@ -1282,52 +1300,60 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");

if (optionalMasked) {
if (permutationMap.getNumResults() !=
static_cast<int64_t>(optionalMasked.size()))
return op->emitOpError("expects the optional masked attr of same rank as "
"permutation_map results: ")
<< AffineMapAttr::get(permutationMap);
}

return success();
}

/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap
mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
}

/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
/// zero, respectively, by default.
/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vector, Value memref, ValueRange indices,
AffineMap permutationMap) {
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
Type elemType = vector.cast<VectorType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
build(builder, result, vector, memref, indices, permutationMap, padding);
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permutationMap,
padding, ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permutationMap, padding,
maskedArrayAttr);
}

/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value memref,
ValueRange indices) {
build(builder, result, vectorType, memref, indices,
getTransferMinorIdentityMap(memref.getType().cast<MemRefType>(),
vectorType));
ValueRange indices, ArrayRef<bool> maybeMasked) {
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
}

template <typename TransferOp>
void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
SmallVector<StringRef, 1> elidedAttrs;
SmallVector<StringRef, 2> elidedAttrs;
if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
op.getMemRefType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
for (auto attr : *maybeMasked) {
if (!attr.template cast<BoolAttr>().getValue()) {
elideMasked = false;
break;
}
}
}
if (elideMasked)
elidedAttrs.push_back(op.getMaskedAttrName());
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
}
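Note: as a consequence of this elision logic (an illustrative round-trip, not a new test), an all-true mask prints back without the attribute, while any false entry keeps it:

  vector.transfer_write %v, %A[%i, %j] {masked = [true, true]} : vector<4x8xf32>, memref<?x?xf32>
  // prints back as:
  vector.transfer_write %v, %A[%i, %j] : vector<4x8xf32>, memref<?x?xf32>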
@@ -1388,7 +1414,8 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";

if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();

if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
@@ -1419,11 +1446,24 @@ static LogicalResult verify(TransferReadOp op) {

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices) {
Value vector, Value memref, ValueRange indices,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
build(builder, result, vector, memref, indices, permMap);
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permMap,
ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
}

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices,
AffineMap permutationMap) {
build(builder, result, vector, memref, indices,
/*maybeMasked=*/ArrayRef<bool>{});
}

static ParseResult parseTransferWriteOp(OpAsmParser &parser,

@@ -1477,7 +1517,8 @@ static LogicalResult verify(TransferWriteOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";

if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();

return verifyPermutationMap(permutationMap,

@@ -564,9 +564,12 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
xferReadOp.permutation_map(), xferReadOp.padding(),
xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType,
resultTupleType, sizes, strides, indices, rewriter,

@@ -620,9 +623,12 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
xferWriteOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
xferWriteOp.permutation_map());
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, resultVectorType,
sourceTupleType, sizes, strides, indices, rewriter,

@@ -918,6 +918,24 @@ func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -
// CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
// CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">

func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
%f7 = constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7 {masked = [false]} :
memref<?xf32>, vector<17xf32>
return %f: vector<17xf32>
}
// CHECK-LABEL: func @transfer_read_1d_not_masked
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
//
// 1. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">

func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>

@@ -220,14 +220,12 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
%f7 = constant 7.0: f32

// CHECK-DAG: %[[cond0:.*]] = constant 1 : i1
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
// CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cond1]] {
// CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref<?x?xf32>, vector<15xf32>
// CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>

@@ -253,7 +251,6 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
// CHECK: %[[cond0:.*]] = constant 1 : i1
// CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
// CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>

@@ -261,8 +258,7 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
// CHECK: scf.if %[[cond1]] {
// CHECK: scf.if %[[cmp]] {
// CHECK: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
// CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
// CHECK: }

@@ -271,3 +267,26 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
vector<17x15xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: transfer_write_progressive_not_masked(
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
// CHECK-NOT: scf.if
// CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
// CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 {
// CHECK-NEXT: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
// CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
vector<17x15xf32>, memref<?x?xf32>
return
}

@@ -348,6 +348,16 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {

// -----

func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<2x3xf32>
// expected-error@+1 {{ expects the optional masked attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {masked = [false], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}

// -----

func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32

@@ -22,6 +22,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : memref<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>

// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>

@@ -29,6 +31,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>

return
}