[mlir] Add "mask" operand to vector.transfer_read/write.

Also factors out out-of-bounds mask generation from vector.transfer_read/write into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001
This commit is contained in:
Matthias Springer 2021-04-07 21:11:55 +09:00
parent c0ef93bec8
commit 65a3f28939
13 changed files with 389 additions and 186 deletions

View File

@ -68,7 +68,7 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool enableIndexOptimizations = true);
bool reassociateFPReductions = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

View File

@ -88,6 +88,10 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
/// `vector.store` and `vector.broadcast`.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr

View File

@ -1135,10 +1135,12 @@ def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
Optional<VectorOf<[I1]>>:$mask,
OptionalAttr<BoolArrayAttr>:$in_bounds)>,
Results<(outs AnyVector:$vector)> {
@ -1167,13 +1169,19 @@ def Vector_TransferReadOp :
return type.
An SSA value `padding` of the same elemental type as the MemRef/Tensor is
provided to specify a fallback value in the case of out-of-bounds accesses.
provided to specify a fallback value in the case of out-of-bounds accesses
and/or masking.
An optional SSA value `mask` of the same shape as the vector type may be
specified to mask out elements. Such elements will be replaced with
`padding`. Elements whose corresponding mask element is `0` are masked out.
An optional boolean array attribute is provided to specify which dimensions
of the transfer are guaranteed to be within bounds. The absence of this
`in_bounds` attribute signifies that any dimension of the transfer may be
out-of-bounds. A `vector.transfer_read` can be lowered to a simple load if
all dimensions are specified to be within bounds.
all dimensions are specified to be within bounds and no `mask` was
specified.
This operation is called 'read' by opposition to 'load' because the
super-vector granularity is generally not representable with a single
@ -1299,6 +1307,14 @@ def Vector_TransferReadOp :
// 'getMinorIdentityMap' (resp. zero).
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that does not set mask.
OpBuilder<(ins "Type":$vector, "Value":$source,
"ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
"ArrayAttr":$inBounds)>,
// Builder that does not set mask.
OpBuilder<(ins "Type":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
"ArrayAttr":$inBounds)>
];
let hasFolder = 1;
@ -1308,11 +1324,13 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
Arguments<(ins AnyVector:$vector, AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
Optional<VectorOf<[I1]>>:$mask,
OptionalAttr<BoolArrayAttr>:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
@ -1341,11 +1359,16 @@ def Vector_TransferWriteOp :
The size of the slice is specified by the size of the vector.
An optional SSA value `mask` of the same shape as the vector type may be
specified to mask out elements. Elements whose corresponding mask element
is `0` are masked out.
An optional boolean array attribute is provided to specify which dimensions
of the transfer are guaranteed to be within bounds. The absence of this
`in_bounds` attribute signifies that any dimension of the transfer may be
out-of-bounds. A `vector.transfer_write` can be lowered to a simple store
if all dimensions are specified to be within bounds.
if all dimensions are specified to be within bounds and no `mask` was
specified.
This operation is called 'write' by opposition to 'store' because the
super-vector granularity is generally not representable with a single
@ -1391,6 +1414,8 @@ def Vector_TransferWriteOp :
"AffineMap":$permutationMap)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
];

View File

@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
Location loc, Type targetType, Value value) {
if (targetType == value.getType())
return value;
bool targetIsIndex = targetType.isIndex();
bool valueIsIndex = value.getType().isIndex();
if (targetIsIndex ^ valueIsIndex)
return rewriter.create<IndexCastOp>(loc, targetType, value);
auto targetIntegerType = targetType.dyn_cast<IntegerType>();
auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
assert(targetIntegerType && valueIntegerType &&
"unexpected cast between types other than integers and index");
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
// much more compact, IR for this operation, but LLVM eventually
// generates more elaborate instructions for this intrinsic since it
// is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
Operation *op, bool enableIndexOptimizations,
int64_t dim, Value b, Value *off = nullptr) {
auto loc = op->getLoc();
// If we can assume all indices fit in 32-bit, we perform the vector
// comparison in 32-bit to get a higher degree of SIMD parallelism.
// Otherwise we perform the vector comparison using 64-bit indices.
Value indices;
Type idxType;
if (enableIndexOptimizations) {
indices = rewriter.create<ConstantOp>(
loc, rewriter.getI32VectorAttr(
llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
idxType = rewriter.getI32Type();
} else {
indices = rewriter.create<ConstantOp>(
loc, rewriter.getI64VectorAttr(
llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
idxType = rewriter.getI64Type();
}
// Add in an offset if requested.
if (off) {
Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
indices = rewriter.create<AddIOp>(loc, ov, indices);
}
// Construct the vector comparison.
Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
align);
return success();
@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
ArrayRef<Value> operands) {
return TransferReadOpAdaptor(operands);
return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
}
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
ArrayRef<Value> operands) {
return TransferWriteOpAdaptor(operands);
return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
}
namespace {
@ -618,33 +558,6 @@ private:
const bool reassociateFPReductions;
};
/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion
: public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
: ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op.getType();
int64_t rank = dstType.getRank();
if (rank == 1) {
rewriter.replaceOp(
op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
dstType.getDimSize(0), operands[0]));
return success();
}
return failure();
}
private:
const bool enableIndexOptimizations;
};
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
@ -1177,20 +1090,12 @@ public:
}
};
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// masked or unmasked read/write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
: ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
@ -1212,6 +1117,9 @@ public:
auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
// Out-of-bounds dims are handled by MaterializeTransferMask.
if (xferOp.hasOutOfBoundsDim())
return failure();
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
@ -1241,40 +1149,24 @@ public:
#endif // ifndef NDEBUG
}
// 1. Get the source/dst address as an LLVM vector pointer.
// Get the source/dst address as an LLVM vector pointer.
VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
Value vectorDataPtr =
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
if (xferOp.isDimInBounds(0))
// Rewrite as an unmasked read / write.
if (!xferOp.mask())
return replaceTransferOpWithLoadOrStore(rewriter,
*this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// 4. Let dim the memref dimension, compute the vector comparison mask
// (in-bounds mask):
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
Value mask = buildVectorComparison(
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
// Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr, mask);
xferOp, operands, vectorDataPtr,
xferOp.mask());
}
private:
const bool enableIndexOptimizations;
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@ -1484,17 +1376,13 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool enableIndexOptimizations) {
bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern,
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(
converter, enableIndexOptimizations);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
converter);
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(

View File

@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
RewritePatternSet patterns(&getContext());
populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
populateVectorToLLVMConversionPatterns(converter, patterns,
reassociateFPReductions);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.

View File

@ -42,7 +42,7 @@ static LogicalResult replaceTransferOpWithMubuf(
LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
Value &glc, Value &slc) {
auto adaptor = TransferWriteOpAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
dwordConfig, vindex,
offsetSizeInBytes, glc, slc);
@ -62,7 +62,7 @@ public:
LogicalResult
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename ConcreteOp::Adaptor adaptor(operands);
typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary());
if (xferOp.getVectorType().getRank() > 1 ||
llvm::size(xferOp.indices()) == 0)

View File

@ -538,6 +538,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
using namespace mlir::edsc::op;
TransferReadOp transfer = cast<TransferReadOp>(op);
if (transfer.mask())
return failure();
auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
@ -624,6 +626,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
using namespace edsc::op;
TransferWriteOp transfer = cast<TransferWriteOp>(op);
if (transfer.mask())
return failure();
auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
if (!memRefType)
return failure();

View File

@ -2295,8 +2295,27 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vectorType, source, indices, permMap, inBounds);
}
/// Builder that does not provide a mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
Type vectorType, Value source, ValueRange indices,
AffineMap permutationMap, Value padding,
ArrayAttr inBounds) {
build(builder, result, vectorType, source, indices, permutationMap, padding,
/*mask=*/Value(), inBounds);
}
/// Builder that does not provide a mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
Type vectorType, Value source, ValueRange indices,
AffineMapAttr permutationMap, Value padding,
ArrayAttr inBounds) {
build(builder, result, vectorType, source, indices, permutationMap, padding,
/*mask=*/Value(), inBounds);
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 2> elidedAttrs;
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
if (op.permutation_map().isMinorIdentity())
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideInBounds = true;
@ -2316,27 +2335,36 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
/// Prints a TransferReadOp as: the operation name, the source operand, the
/// bracketed index list, the padding value, an optional `, mask` operand, the
/// attributes emitted by `printTransferAttrs`, and finally
/// `: <shaped type>, <vector type>`.
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.source() << "[" << op.indices()
<< "], " << op.padding();
// The mask is optional; it is printed only when present.
if (op.mask())
p << ", " << op.mask();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getShapedType() << ", " << op.getVectorType();
}
static ParseResult parseTransferReadOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
llvm::SMLoc typesLoc;
OpAsmParser::OperandType sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
OpAsmParser::OperandType paddingInfo;
SmallVector<Type, 2> types;
OpAsmParser::OperandType maskInfo;
// Parsing with support for paddingValue.
if (parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseComma() || parser.parseOperand(paddingInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseComma() || parser.parseOperand(paddingInfo))
return failure();
// An optional `, mask` operand may follow the padding value.
ParseResult hasMask = parser.parseOptionalComma();
// Propagate a failure to parse the mask operand instead of silently dropping
// it (mirrors the handling in parseTransferWriteOp).
if (hasMask.succeeded() && parser.parseOperand(maskInfo))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = parser.getBuilder().getIndexType();
auto indexType = builder.getIndexType();
auto shapedType = types[0].dyn_cast<ShapedType>();
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
@ -2349,12 +2377,21 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
parser.resolveOperand(paddingInfo, shapedType.getElementType(),
result.operands) ||
parser.addTypeToList(vectorType, result.types));
result.operands))
return failure();
if (hasMask.succeeded()) {
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
result.addAttribute(
TransferReadOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
static_cast<int32_t>(hasMask.succeeded())}));
return parser.addTypeToList(vectorType, result.types);
}
static LogicalResult verify(TransferReadOp op) {
@ -2525,7 +2562,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
/*optional*/ ArrayAttr inBounds) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
inBounds);
/*mask=*/Value(), inBounds);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
@ -2534,24 +2571,39 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
/*optional*/ ArrayAttr inBounds) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
inBounds);
/*mask=*/Value(), inBounds);
}
/// Builder that takes an optional mask and an optional `inBounds` attribute.
/// The result type is the source type when the source is a ranked tensor and
/// a null type otherwise (the `dyn_cast` below yields null in that case).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMap permutationMap, /*optional*/ Value mask,
/*optional*/ ArrayAttr inBounds) {
Type resultType = source.getType().dyn_cast<RankedTensorType>();
build(builder, result, resultType, vector, source, indices, permutationMap,
mask, inBounds);
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
llvm::SMLoc typesLoc;
OpAsmParser::OperandType vectorInfo, sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<Type, 2> types;
OpAsmParser::OperandType maskInfo;
if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
return failure();
ParseResult hasMask = parser.parseOptionalComma();
if (hasMask.succeeded() && parser.parseOperand(maskInfo))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = parser.getBuilder().getIndexType();
auto indexType = builder.getIndexType();
VectorType vectorType = types[0].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
@ -2564,17 +2616,28 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
(shapedType.isa<RankedTensorType>() &&
parser.addTypeToList(shapedType, result.types)));
parser.resolveOperands(indexInfo, indexType, result.operands))
return failure();
if (hasMask.succeeded()) {
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
result.addAttribute(
TransferWriteOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
static_cast<int32_t>(hasMask.succeeded())}));
return failure(shapedType.isa<RankedTensorType>() &&
parser.addTypeToList(shapedType, result.types));
}
/// Prints a TransferWriteOp as: the operation name, the vector operand, the
/// source operand, the bracketed index list, an optional `, mask` operand, the
/// attributes emitted by `printTransferAttrs`, and finally
/// `: <vector type>, <shaped type>`.
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
<< op.indices() << "]";
// The mask is optional; it is printed only when present.
if (op.mask())
p << ", " << op.mask();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getVectorType() << ", " << op.getShapedType();
}

View File

@ -596,6 +596,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
OpBuilder &builder) {
if (!isIdentitySuffix(readOp.permutation_map()))
return nullptr;
if (readOp.mask())
return nullptr;
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
@ -641,6 +643,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
auto writeOp = cast<vector::TransferWriteOp>(op);
if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
if (writeOp.mask())
return failure();
VectorType sourceVectorType = writeOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
TupleType tupleType = generateExtractSlicesOpResultType(
@ -722,6 +726,9 @@ public:
if (ignoreFilter && ignoreFilter(readOp))
return failure();
if (readOp.mask())
return failure();
// TODO: Support splitting TransferReadOp with non-identity permutation
// maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(readOp.permutation_map()))
@ -768,6 +775,9 @@ public:
if (ignoreFilter && ignoreFilter(writeOp))
return failure();
if (writeOp.mask())
return failure();
// TODO: Support splitting TransferWriteOp with non-identity permutation
// maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(writeOp.permutation_map()))
@ -2546,6 +2556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
"Expected splitFullAndPartialTransferPrecondition to hold");
auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
// TODO: add support for write case.
if (!xferReadOp)
return failure();
// Masked transfers are not supported here. The mask may only be queried after
// the null check above: dyn_cast yields a null op for non-read transfers.
if (xferReadOp.mask())
return failure();
@ -2677,6 +2690,8 @@ struct TransferReadExtractPattern
dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
if (!extract)
return failure();
if (read.mask())
return failure();
edsc::ScopedContext scope(rewriter, read.getLoc());
using mlir::edsc::op::operator+;
using mlir::edsc::op::operator*;
@ -2712,6 +2727,8 @@ struct TransferWriteInsertPattern
auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
if (!insert)
return failure();
if (write.mask())
return failure();
edsc::ScopedContext scope(rewriter, write.getLoc());
using mlir::edsc::op::operator+;
using mlir::edsc::op::operator*;
@ -2742,6 +2759,7 @@ struct TransferWriteInsertPattern
/// - If the memref's element type is a vector type then it coincides with the
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// - The op has no mask.
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context)
@ -2780,7 +2798,8 @@ struct TransferReadToVectorLoadLowering
// MaskedLoadOp.
if (read.hasOutOfBoundsDim())
return failure();
if (read.mask())
return failure();
Operation *loadOp;
if (!broadcastedDims.empty() &&
unbroadcastedVectorType.getNumElements() == 1) {
@ -2815,6 +2834,7 @@ struct TransferReadToVectorLoadLowering
/// type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
/// broadcasting is allowed).
/// - The op has no mask.
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context)
@ -2840,6 +2860,8 @@ struct TransferWriteToVectorStoreLowering
// MaskedStoreOp.
if (write.hasOutOfBoundsDim())
return failure();
if (write.mask())
return failure();
rewriter.replaceOpWithNewOp<vector::StoreOp>(
write, write.vector(), write.source(), write.indices());
return success();
@ -2880,6 +2902,8 @@ struct TransferReadPermutationLowering
map.getPermutationMap(permutation, op.getContext());
if (permutationMap.isIdentity())
return failure();
if (op.mask())
return failure();
// Calculate the map of the new read by applying the inverse permutation.
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
@ -2914,6 +2938,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
if (op.mask())
return failure();
AffineMap map = op.permutation_map();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
@ -3062,6 +3088,9 @@ struct CastAwayTransferReadLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
if (read.mask())
return failure();
auto shapedType = read.source().getType().cast<ShapedType>();
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@ -3102,6 +3131,9 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
if (write.mask())
return failure();
auto shapedType = write.source().getType().dyn_cast<ShapedType>();
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@ -3371,6 +3403,151 @@ struct BubbleUpBitCastForStridedSliceInsert
}
};
// Helper that casts `value` to `targetType`; both types are expected to be
// index or integer types (asserted below):
//   * returns `value` unchanged if it already has `targetType`;
//   * creates an IndexCastOp if exactly one of the two types is index;
//   * otherwise creates a SignExtendIOp when the target integer type is wider
//     than the value's type, and a TruncateIOp in the remaining case.
static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
Type targetType, Value value) {
if (targetType == value.getType())
return value;
bool targetIsIndex = targetType.isIndex();
bool valueIsIndex = value.getType().isIndex();
// Exactly one side is index: a single index cast suffices.
if (targetIsIndex ^ valueIsIndex)
return rewriter.create<IndexCastOp>(loc, targetType, value);
// Both types differ and neither is index, so both must be integer types.
auto targetIntegerType = targetType.dyn_cast<IntegerType>();
auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
assert(targetIntegerType && valueIntegerType &&
"unexpected cast between types other than integers and index");
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}
// Helper that builds a vector comparison which evaluates to a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
// where n is `dim`, b is the bound `b` and o is the optional offset `*off`
// (the addition is skipped entirely when `off` is null). `op` is only used
// to obtain the location attached to the created ops.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool enableIndexOptimizations, int64_t dim,
                                   Value b, Value *off = nullptr) {
  Location loc = op->getLoc();

  // When all indices may be assumed to fit in 32 bits, the comparison is done
  // on i32 lanes to get a higher degree of SIMD parallelism; otherwise it is
  // done on i64 lanes.
  Type idxType = enableIndexOptimizations ? rewriter.getI32Type()
                                          : rewriter.getI64Type();

  // Linear lane indices [0, 1, .., dim - 1] in the chosen element type.
  Value indices;
  if (enableIndexOptimizations) {
    auto lanes = llvm::to_vector<4>(llvm::seq<int32_t>(0, dim));
    indices =
        rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(lanes));
  } else {
    auto lanes = llvm::to_vector<4>(llvm::seq<int64_t>(0, dim));
    indices =
        rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(lanes));
  }

  // Shift the lane indices by the (splatted) offset if one was supplied.
  if (off) {
    Value offset = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value offsetSplat =
        rewriter.create<SplatOp>(loc, indices.getType(), offset);
    indices = rewriter.create<AddIOp>(loc, offsetSplat, indices);
  }

  // Compare every lane index against the (splatted) bound.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value boundSplat = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, boundSplat);
}
/// Rewrite pattern that materializes the out-of-bounds check of a vector
/// transfer op (`ConcreteOp`) as an explicit mask operand.
///
/// The pattern only matches ops that have an out-of-bounds dimension, whose
/// vector type has rank at most 1 and that have at least one index. On
/// success the op is updated in place: its mask becomes the in-bounds mask
/// (and-ed with the pre-existing mask, if any) and its `in_bounds` attribute
/// is set to `[true]`.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  /// `enableIndexOpt` is forwarded to buildVectorComparison and selects
  /// 32-bit rather than 64-bit index arithmetic for the mask computation.
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
      : mlir::OpRewritePattern<ConcreteOp>(context),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Nothing to materialize when no dimension is out-of-bounds.
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    VectorType vecType = xferOp.getVectorType();
    if (vecType.getRank() > 1)
      return failure();
    if (llvm::size(xferOp.indices()) == 0)
      return failure();

    // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // * Let dim be the memref dimension, compute the vector comparison mask
    //   (in-bounds mask):
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    const unsigned numLanes = vecType.getNumElements();
    const unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Location loc = xferOp->getLoc();
    Value offset = xferOp.indices()[lastIndex];
    Value bound =
        rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
    Value inBoundsMask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, numLanes, bound, &offset);

    // Intersect the in-bounds mask with the mask specified as an op
    // parameter, if there is one.
    if (Value userMask = xferOp.mask())
      inBoundsMask = rewriter.create<AndOp>(loc, inBoundsMask, userMask);

    // Install the mask; the transfer can now be treated as in-bounds.
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.maskMutable().assign(inBoundsMask);
      xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
    });
    return success();
  }

private:
  const bool enableIndexOptimizations;
};
/// Conversion pattern for a vector.create_mask (1-D only).
///
/// A rank-1 create_mask is replaced by the vector comparison produced by
/// buildVectorComparison, using the size of the mask's only dimension as the
/// lane count and the op's first operand as the bound. Any other rank is
/// left untouched (match failure).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  /// `enableIndexOpt` is forwarded to buildVectorComparison and selects
  /// 32-bit rather than 64-bit index arithmetic for the comparison.
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto maskType = op.getType();
    // Only 1-D masks are handled by this pattern.
    if (maskType.getRank() != 1)
      return failure();

    Value comparison =
        buildVectorComparison(rewriter, op, enableIndexOptimizations,
                              maskType.getDimSize(0), op.getOperand(0));
    rewriter.replaceOp(op, comparison);
    return success();
  }

private:
  const bool enableIndexOptimizations;
};
/// Adds the mask materialization patterns to `patterns`: the 1-D
/// vector.create_mask conversion and the transfer mask materialization for
/// both vector.transfer_read and vector.transfer_write. Every pattern is
/// constructed with the same `enableIndexOptimizations` setting.
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool enableIndexOptimizations) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      ctx, enableIndexOptimizations);
}
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(

View File

@ -3,20 +3,19 @@
// CMP32-LABEL: @genbool_var_1d(
// CMP32-SAME: %[[ARG:.*]]: index)
// CMP32: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
// CMP32: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
// CMP32: %[[T1:.*]] = trunci %[[A]] : i64 to i32
// CMP32: %[[T1:.*]] = index_cast %[[ARG]] : index to i32
// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
// CMP32: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
// CMP32: return %[[T3]] : vector<11xi1>
// CMP64-LABEL: @genbool_var_1d(
// CMP64-SAME: %[[ARG:.*]]: index)
// CMP64: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
// CMP64: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
// CMP64: %[[T1:.*]] = splat %[[A]] : vector<11xi64>
// CMP64: %[[T2:.*]] = cmpi slt, %[[T0]], %[[T1]] : vector<11xi64>
// CMP64: return %[[T2]] : vector<11xi1>
// CMP64: %[[T1:.*]] = index_cast %[[ARG]] : index to i64
// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
// CMP64: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
// CMP64: return %[[T3]] : vector<11xi1>
func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>

View File

@ -1049,31 +1049,31 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-LABEL: func @transfer_read_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
// CHECK: %[[c7:.*]] = constant 7.0
//
// 1. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi32>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[otrunc:.*]] = index_cast %[[BASE]] : index to i32
// CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
// CHECK: %[[offsetVec2:.*]] = addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
//
// 4. Let dim the memref dimension, compute the vector comparison mask:
// 3. Let dim be the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[dtrunc:.*]] = index_cast %[[DIM]] : index to i32
// CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
// CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
//
// 4. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
//
// 5. Rewrite as a masked read.
// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
@ -1081,26 +1081,26 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
//
// 1. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi32>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: addi
//
// 4. Let dim the memref dimension, compute the vector comparison mask:
// 3. Let dim be the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: %[[mask_b:.*]] = cmpi slt, {{.*}} : vector<17xi32>
//
// 4. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
//
// 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 4 : i32} :
@ -1182,6 +1182,21 @@ func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf
// -----
// CHECK-LABEL: func @transfer_read_1d_mask
// CHECK: %[[mask1:.*]] = constant dense<[false, false, true, false, true]>
// CHECK: %[[cmpi:.*]] = cmpi slt
// CHECK: %[[mask2:.*]] = and %[[cmpi]], %[[mask1]]
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
// CHECK: return %[[r]]
// 1-D transfer_read that carries an explicit mask operand and no in_bounds
// attribute. The expectations listed before this function look for the
// user-provided mask being and-ed with a generated `cmpi slt` mask before the
// masked load.
func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
// User-provided mask: only lanes 2 and 4 are enabled.
%m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
// Padding value passed to the transfer_read.
%f7 = constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
return %f: vector<5xf32>
}
// -----
func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
%c0 = constant 0: i32
%v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :

View File

@ -11,6 +11,7 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%c0 = constant 0 : i32
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
%m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
//
// CHECK: vector.transfer_read
@ -27,7 +28,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
%7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@ -39,7 +41,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref<?x?xf32>
vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
return
}

View File

@ -12,6 +12,14 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
return
}
// Reads 13 elements from %A starting at %base with an explicit mask and no
// in_bounds attribute, then prints the result.
func @transfer_read_mask_1d(%A : memref<?xf32>, %base: index) {
// Padding value passed to the transfer_read.
%fm42 = constant -42.0: f32
// User-provided mask: lanes 2 through 9 are enabled, all others disabled.
%m = constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1>
%f = vector.transfer_read %A[%base], %fm42, %m : memref<?xf32>, vector<13xf32>
vector.print %f: vector<13xf32>
return
}
func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base], %fm42
@ -21,6 +29,15 @@ func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
return
}
// Reads 4 elements from %A starting at %base with an explicit mask and the
// transfer marked as in-bounds, then prints the result.
func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
// Padding value passed to the transfer_read.
%fm42 = constant -42.0: f32
// User-provided mask: only lanes 1 and 3 are enabled.
%m = constant dense<[0, 1, 0, 1]> : vector<4xi1>
%f = vector.transfer_read %A[%base], %fm42, %m {in_bounds = [true]}
: memref<?xf32>, vector<4xf32>
vector.print %f: vector<4xf32>
return
}
func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4xf32>
@ -47,6 +64,8 @@ func @entry() {
// Read shifted by 2 and pad with -42:
// ( 2, 3, 4, -42, ..., -42)
call @transfer_read_1d(%A, %c2) : (memref<?xf32>, index) -> ()
// Read with mask and out-of-bounds access.
call @transfer_read_mask_1d(%A, %c2) : (memref<?xf32>, index) -> ()
// Write into memory shifted by 3
// memory contains [[ 0, 1, 2, 0, 0, xxx garbage xxx ]]
call @transfer_write_1d(%A, %c3) : (memref<?xf32>, index) -> ()
@ -56,9 +75,13 @@ func @entry() {
// Read in-bounds 4 @ 1, guaranteed to not overflow.
// Exercises proper alignment.
call @transfer_read_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
// Read in-bounds with mask.
call @transfer_read_mask_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
return
}
// CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( -42, -42, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 1, 2, 0, 0 )
// CHECK: ( -42, 2, -42, 0 )