[mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store

Summary:
This revision adds support for lowering 1-D vector transfers to LLVM.
A mask of the vector length is created by comparing base offset + linear index
against the memref dimension: every position that stays in bounds
(i.e. offset + vector index < dim) has its mask bit set to 1.

A notable fact is that the lowering uses llvm.dialect_cast: it allows writing
the pattern in the simplest form, targeting the simplest mix of the vector and
LLVM dialects and letting other conversions kick in.
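
As an illustration (a sketch simplified from the 17-element test added at the
bottom of this commit; the 4-wide types and SSA names here are illustrative
only, not IR produced verbatim), a 1-D transfer_read at offset %base lowers
roughly to:

  // offsetVector = [ base + 0 .. base + vector_length - 1 ]
  %offsets = llvm.add %base_splat, %linear_indices : !llvm<"<4 x i64>">
  // mask bit i is set iff base + i < dim
  %mask = llvm.icmp "slt" %offsets, %dim_splat : !llvm<"<4 x i64>">
  // masked read with the splatted padding value as pass-through
  %v = llvm.intr.masked.load %vecPtr, %mask, %pad {alignment = 1 : i32}
    : (!llvm<"<4 x float>*">, !llvm<"<4 x i1>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">

A transfer_write lowers analogously to llvm.intr.masked.store.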

Differential Revision: https://reviews.llvm.org/D77703
Nicolas Vasilache 2020-04-09 16:16:32 -04:00
parent 413467f9ec
commit 8345b86d9a
8 changed files with 318 additions and 70 deletions


@@ -398,6 +398,29 @@ public:
  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                            uint64_t value) const;

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
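  // For example, with indices (i_2, i_1) and sizes (s_2, s_1), this produces
  // i_2 * s_1 + i_1.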
  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                            ArrayRef<Value> indices,
                            ArrayRef<Value> allocSizes) const;

  // This is a strided getElementPtr variant that linearizes subscripts as:
  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
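  // For example, a memref<4x8xf32> with the default layout has strides [8, 1]
  // and offset 0, so the element at (%i, %j) is at `base + %i * 8 + %j`.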
  Value getStridedElementPtr(Location loc, Type elementTypePtr,
                             Value descriptor, ArrayRef<Value> indices,
                             ArrayRef<int64_t> strides, int64_t offset,
                             ConversionPatternRewriter &rewriter) const;
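
  // Computes the pointer to the data element addressed by `indices`, resolving
  // strides and offset from `type` and delegating to `getStridedElementPtr`.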
  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                   llvm::Module &module) const;

protected:
  /// Reference to the type converter, with potential extensions.
  LLVMTypeConverter &typeConverter;


@@ -73,6 +73,7 @@ public:
  /// Vector type utilities.
  LLVMType getVectorElementType();
  unsigned getVectorNumElements();
  bool isVectorTy();

  /// Function type utilities.


@@ -111,6 +111,7 @@ public:
  IntegerAttr getI16IntegerAttr(int16_t value);
  IntegerAttr getI32IntegerAttr(int32_t value);
  IntegerAttr getI64IntegerAttr(int64_t value);
  IntegerAttr getIndexAttr(int64_t value);

  /// Signed and unsigned integer attribute getters.
  IntegerAttr getSI32IntegerAttr(int32_t value);


@@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
  return createIndexAttrConstant(builder, loc, getIndexType(), value);
}

Value ConvertToLLVMPattern::linearizeSubscripts(
    ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
    ArrayRef<Value> allocSizes) const {
  assert(indices.size() == allocSizes.size() &&
         "mismatching number of indices and allocation sizes");
  assert(!indices.empty() && "cannot linearize a 0-dimensional access");
  Value linearized = indices.front();
  for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
    linearized = builder.create<LLVM::MulOp>(
        loc, this->getIndexType(), ArrayRef<Value>{linearized, allocSizes[i]});
    linearized = builder.create<LLVM::AddOp>(
        loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
  }
  return linearized;
}

Value ConvertToLLVMPattern::getStridedElementPtr(
    Location loc, Type elementTypePtr, Value descriptor,
    ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
    ConversionPatternRewriter &rewriter) const {
  MemRefDescriptor memRefDescriptor(descriptor);
  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
  Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                          ? memRefDescriptor.offset(rewriter, loc)
                          : this->createIndexConstant(rewriter, loc, offset);
  for (int i = 0, e = indices.size(); i < e; ++i) {
    Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
                       ? memRefDescriptor.stride(rewriter, loc, i)
                       : this->createIndexConstant(rewriter, loc, strides[i]);
    Value additionalOffset =
        rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
    offsetValue =
        rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
  }
  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}

Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
                                       Value memRefDesc,
                                       ArrayRef<Value> indices,
                                       ConversionPatternRewriter &rewriter,
                                       llvm::Module &module) const {
  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  assert(succeeded(successStrides) && "unexpected non-strided memref");
  (void)successStrides;
  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                              offset, rewriter);
}
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
@@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
    MemRefType type = cast<Derived>(op).getMemRefType();
    return isSupportedMemRefType(type) ? success() : failure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                            ArrayRef<Value> indices,
                            ArrayRef<Value> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");
    Value linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
    }
    return linearized;
  }

  // This is a strided getElementPtr variant that linearizes subscripts as:
  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
  Value getStridedElementPtr(Location loc, Type elementTypePtr,
                             Value descriptor, ArrayRef<Value> indices,
                             ArrayRef<int64_t> strides, int64_t offset,
                             ConversionPatternRewriter &rewriter) const {
    MemRefDescriptor memRefDescriptor(descriptor);
    Value base = memRefDescriptor.alignedPtr(rewriter, loc);
    Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                            ? memRefDescriptor.offset(rewriter, loc)
                            : this->createIndexConstant(rewriter, loc, offset);
    for (int i = 0, e = indices.size(); i < e; ++i) {
      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
                         ? memRefDescriptor.stride(rewriter, loc, i)
                         : this->createIndexConstant(rewriter, loc, strides[i]);
      Value additionalOffset =
          rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
      offsetValue =
          rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
    }
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
  }

  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                   llvm::Module &module) const {
    LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(type, strides, offset);
    assert(succeeded(successStrides) && "unexpected non-strided memref");
    (void)successStrides;
    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                                offset, rewriter);
  }
};
// Load operation is lowered to obtaining a pointer to the indexed element


@@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
@@ -894,6 +895,129 @@ public:
  }
};

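// Emits the masked load/store that replaces a 1-D transfer op once the vector
// data pointer and the comparison mask have been computed. Specialized below
// for TransferReadOp (masked load, with the splatted padding value as the
// pass-through operand) and TransferWriteOp (masked store).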
template <typename ConcreteOp>
void replaceTransferOp(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
                       Value mask);

template <>
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
                                       LLVMTypeConverter &typeConverter,
                                       Location loc, Operation *op,
                                       ArrayRef<Value> operands, Value dataPtr,
                                       Value mask) {
  auto xferOp = cast<TransferReadOp>(op);
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
  auto vecTy =
      toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      op, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(1));
}

template <>
void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
                                        LLVMTypeConverter &typeConverter,
                                        Location loc, Operation *op,
                                        ArrayRef<Value> operands, Value dataPtr,
                                        Value mask) {
  auto adaptor = TransferWriteOpOperandAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
}
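
// Returns the operand adaptor matching the concrete transfer op type, so the
// pattern below can be written once for both reads and writes.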
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
  return TransferReadOpOperandAdaptor(operands);
}

static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
  return TransferWriteOpOperandAdaptor(operands);
}
/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);
    if (xferOp.getMemRefType().getRank() != 1)
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    auto vectorDataPtr =
        rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    Value offsetIndex = *(xferOp.indices().begin());
    offsetIndex = rewriter.create<IndexCastOp>(
        loc, vectorCmpType.getElementType(), offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension, compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
    dim = rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(),
                                       dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
                                  vectorDataPtr, mask);
    return success();
  }
};
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
@@ -1079,16 +1203,25 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
                  VectorExtractOpConversion, VectorFMAOp1DConversion,
                  VectorInsertElementOpConversion, VectorInsertOpConversion,
                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
      ctx, converter);
  patterns
      .insert<VectorBroadcastOpConversion,
              VectorReductionOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion>(ctx, converter);
  // clang-format on
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(


@@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
LLVMType LLVMType::getVectorElementType() {
  return get(getContext(), getUnderlyingType()->getVectorElementType());
}
unsigned LLVMType::getVectorNumElements() {
  return getUnderlyingType()->getVectorNumElements();
}
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }

/// Function type utilities.


@@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
  return DictionaryAttr::get(value, context);
}

IntegerAttr Builder::getIndexAttr(int64_t value) {
  return IntegerAttr::get(getIndexType(), APInt(64, value));
}

IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
  return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}


@@ -738,3 +738,95 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
  %f7 = constant 7.0: f32
  %f = vector.transfer_read %A[%base], %f7
      {permutation_map = affine_map<(d0) -> (d0)>} :
    memref<?xf32>, vector<17xf32>
  vector.transfer_write %f, %A[%base]
      {permutation_map = affine_map<(d0) -> (d0)>} :
    vector<17xf32>, memref<?xf32>
  return %f: vector<17xf32>
}
// CHECK-LABEL: func @transfer_read_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
//
// 1. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[offsetVec2:.*]] = llvm.insertelement %[[BASE]], %[[offsetVec]][%[[c0]] :
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
// CHECK: %[[offsetVec3:.*]] = llvm.shufflevector %[[offsetVec2]], %{{.*}} [
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[offsetVec4:.*]] = llvm.add %[[offsetVec3]], %[[linearIndex]] :
// CHECK-SAME: !llvm<"<17 x i64>">
//
// 4. Let dim be the memref dimension, compute the vector comparison mask:
//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
// CHECK-SAME: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
// CHECK: %[[dimVec3:.*]] = llvm.shufflevector %[[dimVec2]], %{{.*}} [
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[mask:.*]] = llvm.icmp "slt" %[[offsetVec4]], %[[dimVec3]] :
// CHECK-SAME: !llvm<"<17 x i64>">
//
// 5. Rewrite as a masked read.
// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
//
// 1. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: llvm.add
//
// 4. Let dim be the memref dimension, compute the vector comparison mask:
//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[mask_b:.*]] = llvm.icmp "slt" {{.*}} : !llvm<"<17 x i64>">
//
// 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 1 : i32} :
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
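
// The RUN header of the test file is outside this hunk; a typical invocation
// for such a test (an assumption, not part of this diff) would be:
//   RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s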