forked from OSchip/llvm-project
[mlir] [VectorOps] Add masked load/store operations to Vector dialect
The intrinsics were already supported and vector.transfer_read/write lowered direclty into these operations. By providing them as individual ops, however, clients can used them directly, and it opens up progressively lowering transfer operations at higher levels (rather than direct lowering to LLVM IR as done now). Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D85357
This commit is contained in:
parent
dd892a33e1
commit
39379916a7
|
@ -1150,6 +1150,102 @@ def Vector_TransferWriteOp :
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_MaskedLoadOp :
|
||||
Vector_Op<"maskedload">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$pass_thru)>,
|
||||
Results<(outs VectorOfRank<[1]>:$result)> {
|
||||
|
||||
let summary = "loads elements from memory into a vector as defined by a mask vector";
|
||||
|
||||
let description = [{
|
||||
The masked load reads elements from memory into a 1-D vector as defined
|
||||
by a base and a 1-D mask vector. When the mask is set, the element is read
|
||||
from memory. Otherwise, the corresponding element is taken from a 1-D
|
||||
pass-through vector. Informally the semantics are:
|
||||
```
|
||||
result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
|
||||
result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
|
||||
etc.
|
||||
```
|
||||
The masked load can be used directly where applicable, or can be used
|
||||
during progressively lowering to bring other memory operations closer to
|
||||
hardware ISA support for a masked load. The semantics of the operation
|
||||
closely correspond to those of the `llvm.masked.load`
|
||||
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = vector.maskedload %base, %mask, %pass_thru
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return base().getType().cast<MemRefType>();
|
||||
}
|
||||
VectorType getMaskVectorType() {
|
||||
return mask().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getPassThruVectorType() {
|
||||
return pass_thru().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getResultVectorType() {
|
||||
return result().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
|
||||
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
|
||||
}
|
||||
|
||||
def Vector_MaskedStoreOp :
|
||||
Vector_Op<"maskedstore">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$value)> {
|
||||
|
||||
let summary = "stores elements from a vector into memory as defined by a mask vector";
|
||||
|
||||
let description = [{
|
||||
The masked store operation writes elements from a 1-D vector into memory
|
||||
as defined by a base and a 1-D mask vector. When the mask is set, the
|
||||
corresponding element from the vector is written to memory. Otherwise,
|
||||
no action is taken for the element. Informally the semantics are:
|
||||
```
|
||||
if (mask[0]) MEM[base+0] = value[0]
|
||||
if (mask[1]) MEM[base+1] = value[1]
|
||||
etc.
|
||||
```
|
||||
The masked store can be used directly where applicable, or can be used
|
||||
during progressively lowering to bring other memory operations closer to
|
||||
hardware ISA support for a masked store. The semantics of the operation
|
||||
closely correspond to those of the `llvm.masked.store`
|
||||
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
vector.maskedstore %base, %mask, %value
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return base().getType().cast<MemRefType>();
|
||||
}
|
||||
VectorType getMaskVectorType() {
|
||||
return mask().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getValueVectorType() {
|
||||
return value().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
|
||||
"type($mask) `,` type($value) `into` type($base)";
|
||||
}
|
||||
|
||||
def Vector_GatherOp :
|
||||
Vector_Op<"gather">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
|
||||
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
|
||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%ld = vector.maskedload %base, %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
func @entry() {
|
||||
// Set up memory.
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
%c16 = constant 16: index
|
||||
%A = alloc(%c16) : memref<?xf32>
|
||||
scf.for %i = %c0 to %c16 step %c1 {
|
||||
%i32 = index_cast %i : index to i32
|
||||
%fi = sitofp %i32 : i32 to f32
|
||||
store %fi, %A[%i] : memref<?xf32>
|
||||
}
|
||||
|
||||
// Set up pass thru vector.
|
||||
%u = constant -7.0: f32
|
||||
%pass = vector.broadcast %u : f32 to vector<16xf32>
|
||||
|
||||
// Set up masks.
|
||||
%f = constant 0: i1
|
||||
%t = constant 1: i1
|
||||
%none = vector.constant_mask [0] : vector<16xi1>
|
||||
%all = vector.constant_mask [16] : vector<16xi1>
|
||||
%some = vector.constant_mask [8] : vector<16xi1>
|
||||
%0 = vector.insert %f, %some[0] : i1 into vector<16xi1>
|
||||
%1 = vector.insert %t, %0[13] : i1 into vector<16xi1>
|
||||
%2 = vector.insert %t, %1[14] : i1 into vector<16xi1>
|
||||
%other = vector.insert %t, %2[14] : i1 into vector<16xi1>
|
||||
|
||||
//
|
||||
// Masked load tests.
|
||||
//
|
||||
|
||||
%l1 = call @maskedload16(%A, %none, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %l1 : vector<16xf32>
|
||||
// CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
|
||||
|
||||
%l2 = call @maskedload16(%A, %all, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %l2 : vector<16xf32>
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
|
||||
|
||||
%l3 = call @maskedload16(%A, %some, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %l3 : vector<16xf32>
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, -7, -7, -7 )
|
||||
|
||||
%l4 = call @maskedload16(%A, %other, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %l4 : vector<16xf32>
|
||||
// CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
|
||||
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
|
||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
func @maskedstore16(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
vector.maskedstore %base, %mask, %value
|
||||
: vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @printmem16(%A: memref<?xf32>) {
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
%c16 = constant 16: index
|
||||
%z = constant 0.0: f32
|
||||
%m = vector.broadcast %z : f32 to vector<16xf32>
|
||||
%mem = scf.for %i = %c0 to %c16 step %c1
|
||||
iter_args(%m_iter = %m) -> (vector<16xf32>) {
|
||||
%c = load %A[%i] : memref<?xf32>
|
||||
%i32 = index_cast %i : index to i32
|
||||
%m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
|
||||
scf.yield %m_new : vector<16xf32>
|
||||
}
|
||||
vector.print %mem : vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @entry() {
|
||||
// Set up memory.
|
||||
%f0 = constant 0.0: f32
|
||||
%c0 = constant 0: index
|
||||
%c1 = constant 1: index
|
||||
%c16 = constant 16: index
|
||||
%A = alloc(%c16) : memref<?xf32>
|
||||
scf.for %i = %c0 to %c16 step %c1 {
|
||||
store %f0, %A[%i] : memref<?xf32>
|
||||
}
|
||||
|
||||
// Set up value vector.
|
||||
%v = vector.broadcast %f0 : f32 to vector<16xf32>
|
||||
%val = scf.for %i = %c0 to %c16 step %c1
|
||||
iter_args(%v_iter = %v) -> (vector<16xf32>) {
|
||||
%i32 = index_cast %i : index to i32
|
||||
%fi = sitofp %i32 : i32 to f32
|
||||
%v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
|
||||
scf.yield %v_new : vector<16xf32>
|
||||
}
|
||||
|
||||
// Set up masks.
|
||||
%t = constant 1: i1
|
||||
%none = vector.constant_mask [0] : vector<16xi1>
|
||||
%some = vector.constant_mask [8] : vector<16xi1>
|
||||
%more = vector.insert %t, %some[13] : i1 into vector<16xi1>
|
||||
%all = vector.constant_mask [16] : vector<16xi1>
|
||||
|
||||
//
|
||||
// Masked store tests.
|
||||
//
|
||||
|
||||
vector.print %val : vector<16xf32>
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
|
||||
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||
|
||||
call @maskedstore16(%A, %none, %val)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||
|
||||
call @maskedstore16(%A, %some, %val)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||
|
||||
call @maskedstore16(%A, %more, %val)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 13, 0, 0 )
|
||||
|
||||
call @maskedstore16(%A, %all, %val)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
|
||||
|
||||
return
|
||||
}
|
|
@ -163,6 +163,19 @@ LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns a bit-casted pointer given a memref base.
|
||||
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value memref, MemRefType memRefType, Type type,
|
||||
Value &ptr) {
|
||||
Value base;
|
||||
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
|
||||
return failure();
|
||||
auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
|
||||
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
|
||||
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns vector of pointers given a memref base and an index
|
||||
// vector.
|
||||
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
|
||||
|
@ -298,6 +311,72 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedload.
|
||||
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorMaskedLoadOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto load = cast<vector::MaskedLoadOp>(op);
|
||||
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, load, align)))
|
||||
return failure();
|
||||
|
||||
auto vtype = typeConverter.convertType(load.getResultVectorType());
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
|
||||
vtype, ptr)))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedstore.
|
||||
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorMaskedStoreOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto store = cast<vector::MaskedStoreOp>(op);
|
||||
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(typeConverter, store, align)))
|
||||
return failure();
|
||||
|
||||
auto vtype = typeConverter.convertType(store.getValueVectorType());
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
|
||||
vtype, ptr)))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
store, adaptor.value(), ptr, adaptor.mask(),
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.gather.
|
||||
class VectorGatherOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
|
@ -1342,6 +1421,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
|||
VectorTransferConversion<TransferReadOp>,
|
||||
VectorTransferConversion<TransferWriteOp>,
|
||||
VectorTypeCastOpConversion,
|
||||
VectorMaskedLoadOpConversion,
|
||||
VectorMaskedStoreOpConversion,
|
||||
VectorGatherOpConversion,
|
||||
VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion,
|
||||
|
|
|
@ -1855,6 +1855,41 @@ Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
|
|||
return llvm::to_vector<4>(getVectorType().getShape());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaskedLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(MaskedLoadOp op) {
|
||||
VectorType maskVType = op.getMaskVectorType();
|
||||
VectorType passVType = op.getPassThruVectorType();
|
||||
VectorType resVType = op.getResultVectorType();
|
||||
|
||||
if (resVType.getElementType() != op.getMemRefType().getElementType())
|
||||
return op.emitOpError("base and result element type should match");
|
||||
|
||||
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
|
||||
return op.emitOpError("expected result dim to match mask dim");
|
||||
if (resVType != passVType)
|
||||
return op.emitOpError("expected pass_thru of same type as result type");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaskedStoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(MaskedStoreOp op) {
|
||||
VectorType maskVType = op.getMaskVectorType();
|
||||
VectorType valueVType = op.getValueVectorType();
|
||||
|
||||
if (valueVType.getElementType() != op.getMemRefType().getElementType())
|
||||
return op.emitOpError("base and value element type should match");
|
||||
|
||||
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
|
||||
return op.emitOpError("expected value dim to match mask dim");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -970,6 +970,26 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
|
|||
// CHECK-SAME: !llvm.vec<16 x float> into !llvm.vec<16 x float>
|
||||
// CHECK: llvm.return %[[T]] : !llvm.vec<16 x float>
|
||||
|
||||
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
|
||||
%0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @masked_load_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
|
||||
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x float>>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float>
|
||||
// CHECK: llvm.return %[[L]] : !llvm.vec<16 x float>
|
||||
|
||||
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
|
||||
vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @masked_store_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
|
||||
// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x float>>
|
||||
// CHECK: llvm.return
|
||||
|
||||
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
|
||||
%0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
|
||||
return %0 : vector<3xf32>
|
||||
|
|
|
@ -1180,6 +1180,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
|
|||
|
||||
// -----
|
||||
|
||||
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
|
||||
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
|
||||
// expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
|
||||
// expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
|
||||
vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
|
||||
// expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
|
||||
vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
|
||||
// expected-error@+1 {{'vector.gather' op base and result element type should match}}
|
||||
%0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
|
||||
|
|
|
@ -369,6 +369,15 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
|
|||
return %0 : vector<16xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @masked_load_and_store
|
||||
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
|
||||
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @gather_and_scatter
|
||||
func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
|
||||
// CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
|
||||
|
|
Loading…
Reference in New Issue