[mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store
Summary: This revision adds support for lowering 1-D vector transfers to LLVM. A mask of the vector length is created that compares the base offset + linear index against the dimension of the memref. In each lane where the access stays in bounds (i.e. offset + vector index < dim), the mask is set to 1.

A notable point is that the lowering uses llvm.dialect_cast to keep the code in its simplest form: it targets the simplest mix of the vector and LLVM dialects and lets the other conversions kick in.

Differential Revision: https://reviews.llvm.org/D77703
commit 8345b86d9a (parent 413467f9ec)
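For readers unfamiliar with the masking scheme, here is a minimal scalar sketch (illustrative only, not part of the patch) of the per-lane predicate that the lowering materializes as a vector compare:

#include <cstdint>
#include <vector>

// Sketch only: the per-lane predicate that the lowering encodes as a vector
// `icmp slt` between [offset, offset+1, ..., offset+N-1] and the splatted dim.
std::vector<bool> computeTransferMask(int64_t offset, int64_t vectorLength,
                                      int64_t dim) {
  std::vector<bool> mask(vectorLength);
  for (int64_t i = 0; i < vectorLength; ++i)
    mask[i] = (offset + i) < dim; // lane is enabled only while it stays in bounds
  return mask;
}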
@@ -398,6 +398,29 @@ public:
  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                            uint64_t value) const;

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                            ArrayRef<Value> indices,
                            ArrayRef<Value> allocSizes) const;

  // This is a strided getElementPtr variant that linearizes subscripts as:
  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
  Value getStridedElementPtr(Location loc, Type elementTypePtr,
                             Value descriptor, ArrayRef<Value> indices,
                             ArrayRef<int64_t> strides, int64_t offset,
                             ConversionPatternRewriter &rewriter) const;

  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                   llvm::Module &module) const;

protected:
  /// Reference to the type converter, with potential extensions.
  LLVMTypeConverter &typeConverter;
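For readers skimming the linearization comment above, a small worked example and a scalar mirror of the accumulation (illustrative only, not part of the patch): for indices (i_2, i_1) = (2, 3) into sizes (s_2, s_1) = (4, 5), the linearized subscript is 2 * 5 + 3 = 13.

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar mirror of linearizeSubscripts: Horner-style accumulation over
// row-major indices and sizes, both given outermost dimension first.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &sizes) {
  assert(indices.size() == sizes.size() && !indices.empty());
  int64_t linearized = indices.front();
  for (size_t i = 1; i < sizes.size(); ++i)
    linearized = linearized * sizes[i] + indices[i]; // mirrors the MulOp + AddOp pair
  return linearized;
}
// linearize({2, 3}, {4, 5}) == 2 * 5 + 3 == 13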
@@ -73,6 +73,7 @@ public:

  /// Vector type utilities.
  LLVMType getVectorElementType();
  unsigned getVectorNumElements();
  bool isVectorTy();

  /// Function type utilities.
@@ -111,6 +111,7 @@ public:
  IntegerAttr getI16IntegerAttr(int16_t value);
  IntegerAttr getI32IntegerAttr(int32_t value);
  IntegerAttr getI64IntegerAttr(int64_t value);
  IntegerAttr getIndexAttr(int64_t value);

  /// Signed and unsigned integer attribute getters.
  IntegerAttr getSI32IntegerAttr(int32_t value);
@@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
  return createIndexAttrConstant(builder, loc, getIndexType(), value);
}

Value ConvertToLLVMPattern::linearizeSubscripts(
    ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
    ArrayRef<Value> allocSizes) const {
  assert(indices.size() == allocSizes.size() &&
         "mismatching number of indices and allocation sizes");
  assert(!indices.empty() && "cannot linearize a 0-dimensional access");

  Value linearized = indices.front();
  for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
    linearized = builder.create<LLVM::MulOp>(
        loc, this->getIndexType(), ArrayRef<Value>{linearized, allocSizes[i]});
    linearized = builder.create<LLVM::AddOp>(
        loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
  }
  return linearized;
}

Value ConvertToLLVMPattern::getStridedElementPtr(
    Location loc, Type elementTypePtr, Value descriptor,
    ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
    ConversionPatternRewriter &rewriter) const {
  MemRefDescriptor memRefDescriptor(descriptor);

  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
  Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                          ? memRefDescriptor.offset(rewriter, loc)
                          : this->createIndexConstant(rewriter, loc, offset);

  for (int i = 0, e = indices.size(); i < e; ++i) {
    Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
                       ? memRefDescriptor.stride(rewriter, loc, i)
                       : this->createIndexConstant(rewriter, loc, strides[i]);
    Value additionalOffset =
        rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
    offsetValue =
        rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
  }
  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}

Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
                                       Value memRefDesc,
                                       ArrayRef<Value> indices,
                                       ConversionPatternRewriter &rewriter,
                                       llvm::Module &module) const {
  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  assert(succeeded(successStrides) && "unexpected non-strided memref");
  (void)successStrides;
  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                              offset, rewriter);
}

/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
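To make the strided addressing above concrete, here is a scalar sketch of the element offset that getStridedElementPtr accumulates before emitting the single GEP (illustrative only; the names are not from the patch):

#include <cstdint>
#include <vector>

// Offset fed to the GEP:
//   base_offset + index_0 * stride_0 + ... + index_n * stride_n
// In the real lowering, dynamic strides/offsets are read from the memref
// descriptor instead of being compile-time constants.
int64_t stridedElementOffset(int64_t baseOffset,
                             const std::vector<int64_t> &indices,
                             const std::vector<int64_t> &strides) {
  int64_t offset = baseOffset;
  for (size_t i = 0; i < indices.size(); ++i)
    offset += indices[i] * strides[i];
  return offset;
}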
@@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
    MemRefType type = cast<Derived>(op).getMemRefType();
    return isSupportedMemRefType(type) ? success() : failure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                            ArrayRef<Value> indices,
                            ArrayRef<Value> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
    }
    return linearized;
  }

  // This is a strided getElementPtr variant that linearizes subscripts as:
  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
  Value getStridedElementPtr(Location loc, Type elementTypePtr,
                             Value descriptor, ArrayRef<Value> indices,
                             ArrayRef<int64_t> strides, int64_t offset,
                             ConversionPatternRewriter &rewriter) const {
    MemRefDescriptor memRefDescriptor(descriptor);

    Value base = memRefDescriptor.alignedPtr(rewriter, loc);
    Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                            ? memRefDescriptor.offset(rewriter, loc)
                            : this->createIndexConstant(rewriter, loc, offset);

    for (int i = 0, e = indices.size(); i < e; ++i) {
      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
                         ? memRefDescriptor.stride(rewriter, loc, i)
                         : this->createIndexConstant(rewriter, loc, strides[i]);
      Value additionalOffset =
          rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
      offsetValue =
          rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
    }
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
  }

  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                   llvm::Module &module) const {
    LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(type, strides, offset);
    assert(succeeded(successStrides) && "unexpected non-strided memref");
    (void)successStrides;
    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                                offset, rewriter);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
@@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
@@ -894,6 +895,129 @@ public:
  }
};

template <typename ConcreteOp>
void replaceTransferOp(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
                       Value mask);

template <>
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
                                       LLVMTypeConverter &typeConverter,
                                       Location loc, Operation *op,
                                       ArrayRef<Value> operands, Value dataPtr,
                                       Value mask) {
  auto xferOp = cast<TransferReadOp>(op);
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      op, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(1));
}

template <>
void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
                                        LLVMTypeConverter &typeConverter,
                                        Location loc, Operation *op,
                                        ArrayRef<Value> operands, Value dataPtr,
                                        Value mask) {
  auto adaptor = TransferWriteOpOperandAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
}

static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
  return TransferReadOpOperandAdaptor(operands);
}

static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
  return TransferWriteOpOperandAdaptor(operands);
}

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);
    if (xferOp.getMemRefType().getRank() != 1)
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    auto vectorDataPtr =
        rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    Value offsetIndex = *(xferOp.indices().begin());
    offsetIndex = rewriter.create<IndexCastOp>(
        loc, vectorCmpType.getElementType(), offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension, compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
    dim =
        rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
                                  vectorDataPtr, mask);

    return success();
  }
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
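As a rough scalar model of what the masked read produced by this pattern computes (assumed semantics of the masked load, for illustration only; not part of the patch): in-bounds lanes read from memory, masked-off lanes take the padding value splatted into the pass-through vector.

#include <cstdint>
#include <vector>

// Scalar model of the lowered 1-D transfer_read: lane i is live only while
// offset + i < dim; dead lanes yield the padding value.
std::vector<float> maskedTransferRead(const float *base, int64_t offset,
                                      int64_t dim, int64_t vectorLength,
                                      float padding) {
  std::vector<float> result(vectorLength);
  for (int64_t i = 0; i < vectorLength; ++i)
    result[i] = (offset + i) < dim ? base[offset + i] : padding;
  return result;
}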
@@ -1079,16 +1203,25 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
                  VectorExtractOpConversion, VectorFMAOp1DConversion,
                  VectorInsertElementOpConversion, VectorInsertOpConversion,
                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
      ctx, converter);
  patterns
      .insert<VectorBroadcastOpConversion,
              VectorReductionOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
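A hypothetical call site, sketched under the assumption of the MLIR APIs of this era (header paths and helper names may differ in other revisions; this is not part of the patch). It only assembles the pattern list that a -convert-vector-to-llvm style pass would hand to the dialect conversion driver; target setup and applyPartialConversion are omitted.

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

static void collectVectorToLLVMPatterns(mlir::LLVMTypeConverter &converter,
                                        mlir::OwningRewritePatternList &patterns) {
  // Vector -> LLVM patterns, now including the 1-D transfer_read/write
  // lowerings added by this commit, plus the matrix intrinsic patterns.
  mlir::populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  mlir::populateVectorToLLVMConversionPatterns(converter, patterns);
  // Standard -> LLVM patterns so the std/vector ops emitted by the lowering
  // (splat, cmpi, addi, index_cast, dim, ...) are themselves converted.
  mlir::populateStdToLLVMConversionPatterns(converter, patterns);
}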
@@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
LLVMType LLVMType::getVectorElementType() {
  return get(getContext(), getUnderlyingType()->getVectorElementType());
}
unsigned LLVMType::getVectorNumElements() {
  return getUnderlyingType()->getVectorNumElements();
}
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }

/// Function type utilities.
@@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
  return DictionaryAttr::get(value, context);
}

IntegerAttr Builder::getIndexAttr(int64_t value) {
  return IntegerAttr::get(getIndexType(), APInt(64, value));
}

IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
  return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
@@ -738,3 +738,95 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">

func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
  %f7 = constant 7.0: f32
  %f = vector.transfer_read %A[%base], %f7
      {permutation_map = affine_map<(d0) -> (d0)>} :
    memref<?xf32>, vector<17xf32>
  vector.transfer_write %f, %A[%base]
      {permutation_map = affine_map<(d0) -> (d0)>} :
    vector<17xf32>, memref<?xf32>
  return %f: vector<17xf32>
}
// CHECK-LABEL: func @transfer_read_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
//
// 1. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[offsetVec2:.*]] = llvm.insertelement %[[BASE]], %[[offsetVec]][%[[c0]] :
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
// CHECK: %[[offsetVec3:.*]] = llvm.shufflevector %[[offsetVec2]], %{{.*}} [
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[offsetVec4:.*]] = llvm.add %[[offsetVec3]], %[[linearIndex]] :
// CHECK-SAME: !llvm<"<17 x i64>">
//
// 4. Let dim be the memref dimension, compute the vector comparison mask:
//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
// CHECK-SAME: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
// CHECK: %[[dimVec3:.*]] = llvm.shufflevector %[[dimVec2]], %{{.*}} [
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[mask:.*]] = llvm.icmp "slt" %[[offsetVec4]], %[[dimVec3]] :
// CHECK-SAME: !llvm<"<17 x i64>">
//
// 5. Rewrite as a masked read.
// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">

//
// 1. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: llvm.add
//
// 4. Let dim be the memref dimension, compute the vector comparison mask:
//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
// CHECK: %[[mask_b:.*]] = llvm.icmp "slt" {{.*}} : !llvm<"<17 x i64>">
//
// 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 1 : i32} :
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">