[mlir][vector] set alignment when lowering transfer_read and transfer_write.

When emitting masked load / store, set alignment from data layout.

Differential Revision: https://reviews.llvm.org/D79246
This commit is contained in:
Wen-Heng (Jack) Chung 2020-05-07 11:39:15 +02:00 committed by Alex Zinenko
parent a31f4c52bf
commit a23f190213
2 changed files with 30 additions and 8 deletions

View File

@ -752,6 +752,19 @@ void replaceTransferOp(ConversionPatternRewriter &rewriter,
Operation *op, ArrayRef<Value> operands, Value dataPtr,
Value mask);
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
Type type, LLVM::LLVMType &llvmType,
unsigned &align) {
auto convertedType = typeConverter.convertType(type);
if (!convertedType)
return failure();
llvmType = convertedType.template cast<LLVM::LLVMType>();
auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
return success();
}
template <>
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
@ -764,10 +777,13 @@ void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
op, vecTy, dataPtr, mask, ValueRange{fill},
rewriter.getI32IntegerAttr(1));
LLVM::LLVMType vecTy;
unsigned align;
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
op, vecTy, dataPtr, mask, ValueRange{fill},
rewriter.getI32IntegerAttr(align));
}
template <>
@ -777,8 +793,14 @@ void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Value dataPtr,
Value mask) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
auto xferOp = cast<TransferWriteOp>(op);
LLVM::LLVMType vecTy;
unsigned align;
if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
vecTy, align)))
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
}
static TransferReadOpOperandAdaptor

View File

@ -818,7 +818,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
//
@ -850,7 +850,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
//
// 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 1 : i32} :
// CHECK-SAME: {alignment = 128 : i32} :
// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {