[mlir][Vector] Add an optional "masked" boolean array attribute to vector transfer operations

Summary:
The semantics of vector transfer ops are extended to allow specifying a per-dimension `masked`
attribute. When the attribute is false for a particular dimension, lowering to LLVM emits
unmasked load and store operations.
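
A minimal sketch of the new attribute (assuming `%A : memref<?xf32>` and `%i : index` are in scope):

  %f7 = constant 7.0 : f32
  // Default: the dimension is masked; lowering uses the masked-load path.
  %v0 = vector.transfer_read %A[%i], %f7 : memref<?xf32>, vector<17xf32>
  // Dimension 0 declared unmasked; lowering may emit a plain (unmasked) load.
  %v1 = vector.transfer_read %A[%i], %f7 {masked = [false]} : memref<?xf32>, vector<17xf32>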

Differential Revision: https://reviews.llvm.org/D80098
Nicolas Vasilache 2020-05-18 11:51:56 -04:00
parent 36cdc17f8c
commit 1870e787af
10 changed files with 255 additions and 89 deletions

View File

@ -865,7 +865,12 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferOpUtils {
code extraTransferDeclaration = [{
static StringRef getMaskedAttrName() { return "masked"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
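// Returns true if dimension `dim` of the transfer requires masking; absence of
// the `masked` attribute means every dimension is masked.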
bool isMaskedDim(unsigned dim) {
return !masked() ||
masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
}
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
@ -878,14 +883,15 @@ def Vector_TransferOpUtils {
def Vector_TransferReadOp :
Vector_Op<"transfer_read">,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding)>,
AffineMapAttr:$permutation_map, AnyType:$padding,
OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs AnyVector:$vector)> {
let summary = "Reads a supervector from memory into an SSA vector value.";
let description = [{
The `vector.transfer_read` op performs a blocking read from a slice within
a [MemRef](../LangRef.md#memref-type) supplied as its first operand
The `vector.transfer_read` op performs a read from a slice within a
[MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same base elemental type.
A memref operand with vector element type must have its vector element
@ -893,8 +899,9 @@ def Vector_TransferReadOp :
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
[attribute](../LangRef.md#attributes) is an
supplied as the operands `2 .. 1 + rank(memref)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
omitted from parsing and printing if it is the canonical minor identity map
@ -906,6 +913,12 @@ def Vector_TransferReadOp :
An `ssa-value` of the same elemental type as the MemRef is provided as the
last operand to specify padding in the case of out-of-bounds accesses.
An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
masking, the `vector.transfer_read` may be lowered to simple loads. The
absence of this `masked` attribute signifies that all dimensions of the
transfer need to be masked.
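For example, the following read declares the leading dimension as not needing
masking (an illustrative sketch; `%A`, `%i`, `%j` and `%pad` are assumed to be
in scope and the permutation map defaults to the minor identity):
  %v = vector.transfer_read %A[%i, %j], %pad {masked = [false, true]} :
    memref<?x?xf32>, vector<4x8xf32>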
This operation is called 'read' as opposed to 'load' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_read` is thus a mid-level abstraction
@ -1015,11 +1028,13 @@ def Vector_TransferReadOp :
let builders = [
// Builder that sets padding to zero.
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
"Value memref, ValueRange indices, AffineMap permutationMap">,
"Value memref, ValueRange indices, AffineMap permutationMap, "
"ArrayRef<bool> maybeMasked = {}">,
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
"Value memref, ValueRange indices">
"Value memref, ValueRange indices, "
"ArrayRef<bool> maybeMasked = {}">
];
let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
@ -1039,12 +1054,13 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write">,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map)> {
AffineMapAttr:$permutation_map,
OptionalAttr<BoolArrayAttr>:$masked)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
let description = [{
The `vector.transfer_write` performs a blocking write from a
The `vector.transfer_write` op performs a write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
slice within a [MemRef](../LangRef.md#memref-type) of the same base
elemental type, supplied as its second operand.
@ -1055,6 +1071,7 @@ def Vector_TransferWriteOp :
The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `3 .. 2 + rank(memref)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The permutation map may be implicit and
@ -1063,6 +1080,12 @@ def Vector_TransferWriteOp :
The size of the slice is specified by the size of the vector.
An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
masking, the `vector.transfer_write` may be lowered to simple stores. The
absence of this `masked` attribute signifies that all dimensions of the
transfer need to be masked.
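For example (an illustrative sketch; `%v`, `%A`, `%i` and `%j` are assumed to
be in scope), a write that needs no masking on either dimension:
  vector.transfer_write %v, %A[%i, %j] {masked = [false, false]} :
    vector<4x8xf32>, memref<?x?xf32>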
This operation is called 'write' as opposed to 'store' because the
super-vector granularity is generally not representable with a single
hardware register. A `vector.transfer_write` is thus a
@ -1097,7 +1120,10 @@ def Vector_TransferWriteOp :
let builders = [
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
"Value memref, ValueRange indices">
"Value memref, ValueRange indices, "
"ArrayRef<bool> maybeMasked = {}">,
OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
"Value memref, ValueRange indices, AffineMap permutationMap">,
];
let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #

View File

@ -746,12 +746,6 @@ public:
}
};
template <typename ConcreteOp>
LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Operation *op, ArrayRef<Value> operands,
Value dataPtr, Value mask);
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
Type type, LLVM::LLVMType &llvmType,
unsigned &align) {
@ -765,12 +759,25 @@ LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
return success();
}
template <>
LogicalResult replaceTransferOp<TransferReadOp>(
ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
Value mask) {
auto xferOp = cast<TransferReadOp>(op);
LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
return success();
}
LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
Location loc, TransferReadOp xferOp,
ArrayRef<Value> operands,
Value dataPtr, Value mask) {
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
VectorType fillType = xferOp.getVectorType();
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
@ -783,19 +790,32 @@ LogicalResult replaceTransferOp<TransferReadOp>(
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
op, vecTy, dataPtr, mask, ValueRange{fill},
xferOp, vecTy, dataPtr, mask, ValueRange{fill},
rewriter.getI32IntegerAttr(align));
return success();
}
template <>
LogicalResult replaceTransferOp<TransferWriteOp>(
ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
Value mask) {
LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
return success();
}
auto xferOp = cast<TransferWriteOp>(op);
LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
Location loc, TransferWriteOp xferOp,
ArrayRef<Value> operands,
Value dataPtr, Value mask) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
LLVM::LLVMType vecTy;
unsigned align;
if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
@ -803,7 +823,8 @@ LogicalResult replaceTransferOp<TransferWriteOp>(
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
return success();
}
@ -877,6 +898,10 @@ public:
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc, vecTy.getPointerTo(), dataPtr);
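// If the (single) vector dimension is not masked, lower directly to an
// unmasked load/store and skip materializing the mask.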
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
unsigned vecWidth = vecTy.getVectorNumElements();
VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
@ -910,8 +935,8 @@ public:
mask);
// 5. Rewrite as a masked read / write.
return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op,
operands, vectorDataPtr, mask);
return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
operands, vectorDataPtr, mask);
}
};

View File

@ -157,25 +157,34 @@ void NDTransferOpHelper<ConcreteOp>::emitInBounds(
ValueRange majorIvs, ValueRange majorOffsets,
MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder,
LambdaElse elseBlockBuilder) {
Value inBounds = std_constant_int(/*value=*/1, /*width=*/1);
Value inBounds;
SmallVector<Value, 4> majorIvsPlusOffsets;
majorIvsPlusOffsets.reserve(majorIvs.size());
unsigned idx = 0;
for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
Value inBounds2 = majorIvsPlusOffsets.back() < ub;
inBounds = inBounds && inBounds2;
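// Only accumulate an in-bounds check for dimensions that are masked;
// unmasked dimensions need no guard.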
if (xferOp.isMaskedDim(leadingRank + idx)) {
Value inBounds2 = majorIvsPlusOffsets.back() < ub;
inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
}
++idx;
}
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBounds,
/*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
if (std::is_same<ConcreteOp, TransferReadOp>())
BlockBuilder(&ifOp.elseRegion().front(),
Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
if (inBounds) {
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBounds,
/*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
if (std::is_same<ConcreteOp, TransferReadOp>())
BlockBuilder(&ifOp.elseRegion().front(),
Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
} else {
// Just build the body of the then block right here.
thenBlockBuilder(majorIvsPlusOffsets);
}
}
template <>
@ -192,13 +201,18 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
// Lower to 1-D vector_transfer_read and let recursion handle it.
Value memref = xferOp.memref();
auto map = TransferReadOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
auto loaded1D =
vector_transfer_read(minorVectorType, memref, indexing,
AffineMapAttr::get(map), xferOp.padding());
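// The 1-D transfer inherits masking from the innermost dimension of the
// original transfer.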
ArrayAttr masked;
if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
AffineMapAttr::get(map),
xferOp.padding(), masked);
// Store the 1-D vector.
std_store(loaded1D, alloc, majorIvs);
};
@ -229,7 +243,6 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
ValueRange majorOffsets, ValueRange minorOffsets,
MemRefBoundsCapture &memrefBounds) {
auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
// Lower to 1-D vector_transfer_write and let recursion handle it.
SmallVector<Value, 8> indexing;
indexing.reserve(leadingRank + majorRank + minorRank);
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
@ -239,8 +252,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
Value loaded1D = std_load(alloc, majorIvs);
auto map = TransferWriteOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
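// As on the read path, the 1-D transfer inherits masking from the innermost
// dimension.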
ArrayAttr masked;
if (xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
vector_transfer_write(loaded1D, xferOp.memref(), indexing,
AffineMapAttr::get(map));
AffineMapAttr::get(map), masked);
};
// Don't write anything when out of bounds.
auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};

View File

@ -1017,8 +1017,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::TransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices,
AffineMapAttr::get(permutationMap));
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = transfer.getOperation();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminals" (i.e. AffineStoreOps) are erased on the spot.

View File

@ -1202,6 +1202,23 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
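/// For example, for memref<?x?xvector<4x3xf32>> and vector<1x1x4x3xf32> the
/// minor identity map is (d0, d1) -> (d0, d1).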
AffineMap
mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
@ -1233,7 +1250,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
VectorType vectorType,
AffineMap permutationMap) {
AffineMap permutationMap,
ArrayAttr optionalMasked) {
auto memrefElementType = memrefType.getElementType();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
@ -1282,52 +1300,60 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");
if (optionalMasked) {
if (permutationMap.getNumResults() !=
static_cast<int64_t>(optionalMasked.size()))
return op->emitOpError("expects the optional masked attr of same rank as "
"permutation_map results: ")
<< AffineMapAttr::get(permutationMap);
}
return success();
}
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap
mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
memRefType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
memRefType.getRank(), vectorType.getRank() - elementVectorRank,
memRefType.getContext());
}
/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
/// zero, respectively, by default.
/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vector, Value memref, ValueRange indices,
AffineMap permutationMap) {
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
Type elemType = vector.cast<VectorType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
build(builder, result, vector, memref, indices, permutationMap, padding);
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permutationMap,
padding, ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permutationMap, padding,
maskedArrayAttr);
}
/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value memref,
ValueRange indices) {
build(builder, result, vectorType, memref, indices,
getTransferMinorIdentityMap(memref.getType().cast<MemRefType>(),
vectorType));
ValueRange indices, ArrayRef<bool> maybeMasked) {
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
}
template <typename TransferOp>
void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
SmallVector<StringRef, 1> elidedAttrs;
SmallVector<StringRef, 2> elidedAttrs;
if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
op.getMemRefType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
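// Elide the `masked` attribute when it is absent or when every dimension is
// masked (the default).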
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
for (auto attr : *maybeMasked) {
if (!attr.template cast<BoolAttr>().getValue()) {
elideMasked = false;
break;
}
}
}
if (elideMasked)
elidedAttrs.push_back(op.getMaskedAttrName());
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
}
@ -1388,7 +1414,8 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
@ -1419,11 +1446,24 @@ static LogicalResult verify(TransferReadOp op) {
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices) {
Value vector, Value memref, ValueRange indices,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
build(builder, result, vector, memref, indices, permMap);
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permMap,
ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
}
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value memref, ValueRange indices,
AffineMap permutationMap) {
build(builder, result, vector, memref, indices,
/*maybeMasked=*/ArrayRef<bool>{});
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
@ -1477,7 +1517,8 @@ static LogicalResult verify(TransferWriteOp op) {
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();
return verifyPermutationMap(permutationMap,

View File

@ -564,9 +564,12 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding());
xferReadOp.permutation_map(), xferReadOp.padding(),
xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType,
resultTupleType, sizes, strides, indices, rewriter,
@ -620,9 +623,12 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
xferWriteOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
xferWriteOp.permutation_map());
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, resultVectorType,
sourceTupleType, sizes, strides, indices, rewriter,

View File

@ -918,6 +918,24 @@ func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -
// CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
// CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">
func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
%f7 = constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7 {masked = [false]} :
memref<?xf32>, vector<17xf32>
return %f: vector<17xf32>
}
// CHECK-LABEL: func @transfer_read_1d_not_masked
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
//
// 1. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>

View File

@ -220,14 +220,12 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
%f7 = constant 7.0: f32
// CHECK-DAG: %[[cond0:.*]] = constant 1 : i1
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
// CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cond1]] {
// CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref<?x?xf32>, vector<15xf32>
// CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
@ -253,7 +251,6 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
// CHECK: %[[cond0:.*]] = constant 1 : i1
// CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
// CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
@ -261,8 +258,7 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// CHECK: affine.for %[[I:.*]] = 0 to 17 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: %[[cond1:.*]] = and %[[cmp]], %[[cond0]] : i1
// CHECK: scf.if %[[cond1]] {
// CHECK: scf.if %[[cmp]] {
// CHECK: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
// CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
// CHECK: }
@ -271,3 +267,26 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
vector<17x15xf32>, memref<?x?xf32>
return
}
// -----
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-LABEL: transfer_write_progressive_not_masked(
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
// CHECK-NOT: scf.if
// CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
// CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
// CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 {
// CHECK-NEXT: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
// CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
vector<17x15xf32>, memref<?x?xf32>
return
}

View File

@ -348,6 +348,16 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
// -----
func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<2x3xf32>
// expected-error@+1 {{ expects the optional masked attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {masked = [false], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
// -----
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32

View File

@ -22,6 +22,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : memref<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@ -29,6 +31,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
return
}