forked from OSchip/llvm-project
[mlir][vector] set alignment when lowering transfer_read and transfer_write.
When emitting masked load / store, set alignment from data layout. Differential Revision: https://reviews.llvm.org/D79246
This commit is contained in:
parent
a31f4c52bf
commit
a23f190213
|
@ -752,6 +752,19 @@ void replaceTransferOp(ConversionPatternRewriter &rewriter,
|
|||
Operation *op, ArrayRef<Value> operands, Value dataPtr,
|
||||
Value mask);
|
||||
|
||||
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
|
||||
Type type, LLVM::LLVMType &llvmType,
|
||||
unsigned &align) {
|
||||
auto convertedType = typeConverter.convertType(type);
|
||||
if (!convertedType)
|
||||
return failure();
|
||||
|
||||
llvmType = convertedType.template cast<LLVM::LLVMType>();
|
||||
auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
|
||||
align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
|
@ -764,10 +777,13 @@ void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
|
|||
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
|
||||
fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
|
||||
|
||||
auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
op, vecTy, dataPtr, mask, ValueRange{fill},
|
||||
rewriter.getI32IntegerAttr(1));
|
||||
LLVM::LLVMType vecTy;
|
||||
unsigned align;
|
||||
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
|
||||
vecTy, align)))
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
op, vecTy, dataPtr, mask, ValueRange{fill},
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -777,8 +793,14 @@ void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
|
|||
ArrayRef<Value> operands, Value dataPtr,
|
||||
Value mask) {
|
||||
auto adaptor = TransferWriteOpOperandAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
|
||||
|
||||
auto xferOp = cast<TransferWriteOp>(op);
|
||||
LLVM::LLVMType vecTy;
|
||||
unsigned align;
|
||||
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
|
||||
vecTy, align)))
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
|
||||
}
|
||||
|
||||
static TransferReadOpOperandAdaptor
|
||||
|
|
|
@ -818,7 +818,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
|
|||
// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
|
||||
// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
|
||||
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
|
||||
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
|
||||
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
|
||||
// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
|
||||
|
||||
//
|
||||
|
@ -850,7 +850,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
|
|||
//
|
||||
// 5. Rewrite as a masked write.
|
||||
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
|
||||
// CHECK-SAME: {alignment = 1 : i32} :
|
||||
// CHECK-SAME: {alignment = 128 : i32} :
|
||||
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
|
||||
|
||||
func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
|
||||
|
|
Loading…
Reference in New Issue