[mlir] [VectorOps] Add masked load/store operations to Vector dialect

The intrinsics were already supported, and vector.transfer_read/write lowered
directly into these operations. By providing them as individual ops, however,
clients can use them directly, and it opens up progressive lowering of transfer
operations at higher levels (rather than the direct lowering to LLVM IR done now).
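
For reference, the surface syntax of the two new ops (as documented and
exercised in the tests below):

```mlir
%0 = vector.maskedload %base, %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.maskedstore %base, %mask, %value
    : vector<16xi1>, vector<16xf32> into memref<?xf32>
```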

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D85357
Author: aartbik
Date:   2020-08-05 13:43:16 -07:00
commit 39379916a7 (parent dd892a33e1)
8 changed files with 431 additions and 0 deletions


@@ -1150,6 +1150,102 @@ def Vector_TransferWriteOp :
let hasFolder = 1;
}
def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
let summary = "loads elements from memory into a vector as defined by a mask vector";
let description = [{
The masked load reads elements from memory into a 1-D vector as defined
by a base and a 1-D mask vector. When the mask is set, the element is read
from memory. Otherwise, the corresponding element is taken from a 1-D
pass-through vector. Informally the semantics are:
```
result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
etc.
```
The masked load can be used directly where applicable, or can be used
during progressive lowering to bring other memory operations closer to
hardware ISA support for a masked load. The semantics of the operation
closely correspond to those of the `llvm.masked.load`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
Example:
```mlir
%0 = vector.maskedload %base, %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return result().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
}
def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
let description = [{
The masked store operation writes elements from a 1-D vector into memory
as defined by a base and a 1-D mask vector. When the mask is set, the
corresponding element from the vector is written to memory. Otherwise,
no action is taken for the element. Informally the semantics are:
```
if (mask[0]) MEM[base+0] = value[0]
if (mask[1]) MEM[base+1] = value[1]
etc.
```
The masked store can be used directly where applicable, or can be used
during progressive lowering to bring other memory operations closer to
hardware ISA support for a masked store. The semantics of the operation
closely correspond to those of the `llvm.masked.store`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
Example:
```mlir
vector.maskedstore %base, %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getValueVectorType() {
return value().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
"type($mask) `,` type($value) `into` type($base)";
}
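As a sketch of how the two ops compose (a hypothetical usage example, not part
of this patch; `%n`, `%src`, and `%dst` are assumed bindings, and
`vector.create_mask` is used to build the bounds mask):
```mlir
// Copy the first %n (0 <= %n <= 16) elements of %src into %dst; lanes at or
// beyond %n are neither read from %src nor written to %dst.
%mask = vector.create_mask %n : vector<16xi1>
%zero = constant 0.0 : f32
%pass = vector.broadcast %zero : f32 to vector<16xf32>
%v = vector.maskedload %src, %mask, %pass
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.maskedstore %dst, %mask, %v
    : vector<16xi1>, vector<16xf32> into memref<?xf32>
```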
def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins AnyMemRef:$base,


@@ -0,0 +1,66 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%ld = vector.maskedload %base, %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
func @entry() {
// Set up memory.
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%A = alloc(%c16) : memref<?xf32>
scf.for %i = %c0 to %c16 step %c1 {
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
store %fi, %A[%i] : memref<?xf32>
}
// Set up pass thru vector.
%u = constant -7.0: f32
%pass = vector.broadcast %u : f32 to vector<16xf32>
// Set up masks.
%f = constant 0: i1
%t = constant 1: i1
%none = vector.constant_mask [0] : vector<16xi1>
%all = vector.constant_mask [16] : vector<16xi1>
%some = vector.constant_mask [8] : vector<16xi1>
%0 = vector.insert %f, %some[0] : i1 into vector<16xi1>
%1 = vector.insert %t, %0[13] : i1 into vector<16xi1>
%2 = vector.insert %t, %1[14] : i1 into vector<16xi1>
%other = vector.insert %t, %2[14] : i1 into vector<16xi1>
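// %other is now true for lanes 1-7, 13 and 14; lanes 0, 8-12 and 15 stay off.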
//
// Masked load tests.
//
%l1 = call @maskedload16(%A, %none, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %l1 : vector<16xf32>
// CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
%l2 = call @maskedload16(%A, %all, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %l2 : vector<16xf32>
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
%l3 = call @maskedload16(%A, %some, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %l3 : vector<16xf32>
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, -7, -7, -7 )
%l4 = call @maskedload16(%A, %other, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %l4 : vector<16xf32>
// CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
return
}


@@ -0,0 +1,89 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
func @maskedstore16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
vector.maskedstore %base, %mask, %value
: vector<16xi1>, vector<16xf32> into memref<?xf32>
return
}
func @printmem16(%A: memref<?xf32>) {
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%z = constant 0.0: f32
%m = vector.broadcast %z : f32 to vector<16xf32>
%mem = scf.for %i = %c0 to %c16 step %c1
iter_args(%m_iter = %m) -> (vector<16xf32>) {
%c = load %A[%i] : memref<?xf32>
%i32 = index_cast %i : index to i32
%m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
scf.yield %m_new : vector<16xf32>
}
vector.print %mem : vector<16xf32>
return
}
func @entry() {
// Set up memory.
%f0 = constant 0.0: f32
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%A = alloc(%c16) : memref<?xf32>
scf.for %i = %c0 to %c16 step %c1 {
store %f0, %A[%i] : memref<?xf32>
}
// Set up value vector.
%v = vector.broadcast %f0 : f32 to vector<16xf32>
%val = scf.for %i = %c0 to %c16 step %c1
iter_args(%v_iter = %v) -> (vector<16xf32>) {
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
%v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
scf.yield %v_new : vector<16xf32>
}
// Set up masks.
%t = constant 1: i1
%none = vector.constant_mask [0] : vector<16xi1>
%some = vector.constant_mask [8] : vector<16xi1>
%more = vector.insert %t, %some[13] : i1 into vector<16xi1>
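// %more is the first-8 mask %some with lane 13 also enabled.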
%all = vector.constant_mask [16] : vector<16xi1>
//
// Masked store tests.
//
vector.print %val : vector<16xf32>
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
call @maskedstore16(%A, %none, %val)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
call @maskedstore16(%A, %some, %val)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0 )
call @maskedstore16(%A, %more, %val)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 13, 0, 0 )
call @maskedstore16(%A, %all, %val)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
return
}


@@ -163,6 +163,19 @@ LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
return success();
}
// Helper that returns a bit-casted pointer given a memref base.
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
Value memref, MemRefType memRefType, Type type,
Value &ptr) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
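// A GEP with an empty index list just forwards the base address, now typed as
// a pointer to the vector type expected by the masked load/store intrinsics.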
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
}
// Helper that returns vector of pointers given a memref base and an index
// vector.
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
@@ -298,6 +311,72 @@ public:
}
};
/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorMaskedLoadOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
typeConverter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto load = cast<vector::MaskedLoadOp>(op);
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(typeConverter, load, align)))
return failure();
auto vtype = typeConverter.convertType(load.getResultVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
vtype, ptr)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
rewriter.getI32IntegerAttr(align));
return success();
}
};
/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorMaskedStoreOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
typeConverter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto store = cast<vector::MaskedStoreOp>(op);
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(typeConverter, store, align)))
return failure();
auto vtype = typeConverter.convertType(store.getValueVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
vtype, ptr)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
store, adaptor.value(), ptr, adaptor.mask(),
rewriter.getI32IntegerAttr(align));
return success();
}
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
@@ -1342,6 +1421,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>,
VectorTypeCastOpConversion,
VectorMaskedLoadOpConversion,
VectorMaskedStoreOpConversion,
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,


@@ -1855,6 +1855,41 @@ Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(MaskedLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
VectorType resVType = op.getResultVectorType();
if (resVType.getElementType() != op.getMemRefType().getElementType())
return op.emitOpError("base and result element type should match");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
return op.emitOpError("expected pass_thru of same type as result type");
return success();
}
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(MaskedStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
if (valueVType.getElementType() != op.getMemRefType().getElementType())
return op.emitOpError("base and value element type should match");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected value dim to match mask dim");
return success();
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//


@@ -970,6 +970,26 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
// CHECK-SAME: !llvm.vec<16 x float> into !llvm.vec<16 x float>
// CHECK: llvm.return %[[T]] : !llvm.vec<16 x float>
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @masked_load_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x float>>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float>
// CHECK: llvm.return %[[L]] : !llvm.vec<16 x float>
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
return
}
// CHECK-LABEL: func @masked_store_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x float>>
// CHECK: llvm.return
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
%0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
return %0 : vector<3xf32>


@@ -1180,6 +1180,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
// -----
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
// expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
// expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
}
// -----
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
}
// -----
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
}
// -----
func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
// expected-error@+1 {{'vector.gather' op base and result element type should match}}
%0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>


@@ -369,6 +369,15 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
return %0 : vector<16xi32>
}
// CHECK-LABEL: @masked_load_and_store
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
return
}
// CHECK-LABEL: @gather_and_scatter
func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
// CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>