[mlir][vector] generalized masked l/s and compressed l/s with indices

Adding the ability to index the base address brings these operations closer
to the transfer read and write semantics (with lowering advantages), makes
their use in vector MLIR code more consistent (and easier to read), and
considerably reduces the code duplication needed to lower memrefs into base
addresses (making codegen less error-prone).
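
As a quick illustration, here is a minimal before/after sketch assembled from the op examples and tests updated in the diff below (the value names are placeholders):

```mlir
// Before this change, a masked load always started at the base:
//   %0 = vector.maskedload %base, %mask, %pass_thru
//       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

// After this change, the base is indexed, as in transfer_read/transfer_write:
%c8 = constant 8 : index
%0 = vector.maskedload %base[%c8], %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

// With an all-true mask, the indexed form now canonicalizes directly into an
// unmasked transfer_read at the same offset (see the updated tests below):
//   %f0 = constant 0.0 : f32
//   %1 = vector.transfer_read %base[%c8], %f0 {masked = [false]}
//       : memref<?xf32>, vector<16xf32>
```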

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D94278
Aart Bik 2021-01-08 10:26:57 -08:00
parent 1ba5ea67a3
commit a57def30f5
13 changed files with 323 additions and 271 deletions

View File

@ -1317,6 +1317,7 @@ def Vector_TransferWriteOp :
def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins AnyMemRef:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@ -1325,12 +1326,12 @@ def Vector_MaskedLoadOp :
let description = [{
The masked load reads elements from memory into a 1-D vector as defined
by a base and a 1-D mask vector. When the mask is set, the element is read
from memory. Otherwise, the corresponding element is taken from a 1-D
pass-through vector. Informally the semantics are:
by a base with indices and a 1-D mask vector. When the mask is set, the
element is read from memory. Otherwise, the corresponding element is taken
from a 1-D pass-through vector. Informally the semantics are:
```
result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
result[0] := mask[0] ? base[i+0] : pass_thru[0]
result[1] := mask[1] ? base[i+1] : pass_thru[1]
etc.
```
The masked load can be used directly where applicable, or can be used
@ -1342,7 +1343,7 @@ def Vector_MaskedLoadOp :
Example:
```mlir
%0 = vector.maskedload %base, %mask, %pass_thru
%0 = vector.maskedload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
@ -1360,7 +1361,7 @@ def Vector_MaskedLoadOp :
return result().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
@ -1368,6 +1369,7 @@ def Vector_MaskedLoadOp :
def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Arguments<(ins AnyMemRef:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
@ -1375,12 +1377,12 @@ def Vector_MaskedStoreOp :
let description = [{
The masked store operation writes elements from a 1-D vector into memory
as defined by a base and a 1-D mask vector. When the mask is set, the
corresponding element from the vector is written to memory. Otherwise,
as defined by a base with indices and a 1-D mask vector. When the mask is
set, the corresponding element from the vector is written to memory. Otherwise,
no action is taken for the element. Informally the semantics are:
```
if (mask[0]) MEM[base+0] = value[0]
if (mask[1]) MEM[base+1] = value[1]
if (mask[0]) base[i+0] = value[0]
if (mask[1]) base[i+1] = value[1]
etc.
```
The masked store can be used directly where applicable, or can be used
@ -1392,7 +1394,7 @@ def Vector_MaskedStoreOp :
Example:
```mlir
vector.maskedstore %base, %mask, %value
vector.maskedstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
@ -1407,8 +1409,8 @@ def Vector_MaskedStoreOp :
return value().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
"type($mask) `,` type($value) `into` type($base)";
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}
@ -1430,8 +1432,8 @@ def Vector_GatherOp :
semantics are:
```
if (!defined(pass_thru)) pass_thru = [undef, .., undef]
result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
etc.
```
The vector dialect leaves out-of-bounds behavior undefined.
@ -1487,8 +1489,8 @@ def Vector_ScatterOp :
bit in a 1-D mask vector is set. Otherwise, no action is taken for that
element. Informally the semantics are:
```
if (mask[0]) MEM[base + index[0]] = value[0]
if (mask[1]) MEM[base + index[1]] = value[1]
if (mask[0]) base[index[0]] = value[0]
if (mask[1]) base[index[1]] = value[1]
etc.
```
The vector dialect leaves out-of-bounds and repeated index behavior
@ -1531,6 +1533,7 @@ def Vector_ScatterOp :
def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins AnyMemRef:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@ -1539,13 +1542,13 @@ def Vector_ExpandLoadOp :
let description = [{
The expand load reads elements from memory into a 1-D vector as defined
by a base and a 1-D mask vector. When the mask is set, the next element
is read from memory. Otherwise, the corresponding element is taken from
a 1-D pass-through vector. Informally the semantics are:
by a base with indices and a 1-D mask vector. When the mask is set, the
next element is read from memory. Otherwise, the corresponding element
is taken from a 1-D pass-through vector. Informally the semantics are:
```
index = base
result[0] := mask[0] ? MEM[index++] : pass_thru[0]
result[1] := mask[1] ? MEM[index++] : pass_thru[1]
index = i
result[0] := mask[0] ? base[index++] : pass_thru[0]
result[1] := mask[1] ? base[index++] : pass_thru[1]
etc.
```
Note that the index increment is done conditionally.
@ -1559,7 +1562,7 @@ def Vector_ExpandLoadOp :
Example:
```mlir
%0 = vector.expandload %base, %mask, %pass_thru
%0 = vector.expandload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
@ -1577,7 +1580,7 @@ def Vector_ExpandLoadOp :
return result().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
@ -1585,6 +1588,7 @@ def Vector_ExpandLoadOp :
def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins AnyMemRef:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
@ -1592,13 +1596,13 @@ def Vector_CompressStoreOp :
let description = [{
The compress store operation writes elements from a 1-D vector into memory
as defined by a base and a 1-D mask vector. When the mask is set, the
corresponding element from the vector is written next to memory. Otherwise,
no action is taken for the element. Informally the semantics are:
as defined by a base with indices and a 1-D mask vector. When the mask is
set, the corresponding element from the vector is written next to memory.
Otherwise, no action is taken for the element. Informally the semantics are:
```
index = base
if (mask[0]) MEM[index++] = value[0]
if (mask[1]) MEM[index++] = value[1]
index = i
if (mask[0]) base[index++] = value[0]
if (mask[1]) base[index++] = value[1]
etc.
```
Note that the index increment is done conditionally.
@ -1612,7 +1616,7 @@ def Vector_CompressStoreOp :
Example:
```mlir
vector.compressstore %base, %mask, %value
vector.compressstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
@ -1627,7 +1631,7 @@ def Vector_CompressStoreOp :
return value().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}

View File

@ -5,7 +5,16 @@
func @compress16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
vector.compressstore %base, %mask, %value
%c0 = constant 0: index
vector.compressstore %base[%c0], %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
func @compress16_at8(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c8 = constant 8: index
vector.compressstore %base[%c8], %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@ -86,5 +95,10 @@ func @entry() {
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @compress16_at8(%A, %some1, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 0, 1, 2, 3, 12, 13, 14, 15 )
return
}

View File

@ -5,8 +5,18 @@
func @expand16(%base: memref<?xf32>,
%mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%e = vector.expandload %base, %mask, %pass_thru
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0: index
%e = vector.expandload %base[%c0], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %e : vector<16xf32>
}
func @expand16_at8(%base: memref<?xf32>,
%mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%c8 = constant 8: index
%e = vector.expandload %base[%c8], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %e : vector<16xf32>
}
@ -78,5 +88,10 @@ func @entry() {
vector.print %e6 : vector<16xf32>
// CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
%e7 = call @expand16_at8(%A, %some1, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e7 : vector<16xf32>
// CHECK-NEXT: ( 8, 9, 10, 11, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
return
}

View File

@ -5,7 +5,16 @@
func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%ld = vector.maskedload %base, %mask, %pass_thru
%c0 = constant 0: index
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
func @maskedload16_at8(%base: memref<?xf32>, %mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%c8 = constant 8: index
%ld = vector.maskedload %base[%c8], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@ -61,6 +70,11 @@ func @entry() {
vector.print %l4 : vector<16xf32>
// CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
%l5 = call @maskedload16_at8(%A, %some, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %l5 : vector<16xf32>
// CHECK: ( 8, 9, 10, 11, 12, 13, 14, 15, -7, -7, -7, -7, -7, -7, -7, -7 )
return
}

View File

@ -5,8 +5,17 @@
func @maskedstore16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
vector.maskedstore %base, %mask, %value
: vector<16xi1>, vector<16xf32> into memref<?xf32>
%c0 = constant 0: index
vector.maskedstore %base[%c0], %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
func @maskedstore16_at8(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c8 = constant 8: index
vector.maskedstore %base[%c8], %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@ -85,5 +94,10 @@ func @entry() {
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @maskedstore16_at8(%A, %some, %val)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 )
return
}

View File

@ -173,33 +173,7 @@ static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
return success();
}
// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
Location loc, Value memref,
MemRefType memRefType, Value &ptr) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
}
// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
Location loc, Value memref,
MemRefType memRefType, Type type, Value &ptr) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = LLVM::LLVMPointerType::get(type);
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
}
// Helper that returns vector of pointers given a memref base and an index
// vector.
// Helper that returns vector of pointers given a memref base with index vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
Location loc, Value memref, Value indices,
MemRefType memRefType, VectorType vType,
@ -213,6 +187,18 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
return success();
}
// Casts a strided element pointer to a vector pointer. The vector pointer
// would always be on address space 0, therefore addrspacecast shall be
// used when source/dst memrefs are not on address space 0.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, MemRefType memRefType, Type vt) {
auto pType =
LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
if (memRefType.getMemorySpace() == 0)
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
@ -343,18 +329,18 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
MemRefType memRefType = load.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
align)))
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
auto vtype = typeConverter->convertType(load.getResultVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
vtype, ptr)))
return failure();
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
@ -374,18 +360,18 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
MemRefType memRefType = store.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
align)))
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
auto vtype = typeConverter->convertType(store.getValueVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
vtype, ptr)))
return failure();
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
store, adaptor.value(), ptr, adaptor.mask(),
@ -473,16 +459,15 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
MemRefType memRefType = expand.getMemRefType();
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
ptr)))
return failure();
// Resolve address.
auto vtype = typeConverter->convertType(expand.getResultVectorType());
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
return success();
}
};
@ -498,11 +483,11 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
MemRefType memRefType = compress.getMemRefType();
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(),
compress.getMemRefType(), ptr)))
return failure();
// Resolve address.
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.value(), ptr, adaptor.mask());
@ -1223,21 +1208,11 @@ public:
}
// 1. Get the source/dst address as an LLVM vector pointer.
// The vector pointer would always be on address space 0, therefore
// addrspacecast shall be used when source/dst memrefs are not on
// address space 0.
// TODO: support alignment when possible.
VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
auto vecTy = toLLVMTy(xferOp.getVectorType())
.template cast<LLVM::LLVMFixedVectorType>();
Value vectorDataPtr;
if (memRefType.getMemorySpace() == 0)
vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
else
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
Value vectorDataPtr =
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter,
@ -1251,7 +1226,7 @@ public:
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
unsigned vecWidth = vecTy.getNumElements();
unsigned vecWidth = vtp.getNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);

View File

@ -76,20 +76,6 @@ static MaskFormat get1DMaskFormat(Value mask) {
return MaskFormat::Unknown;
}
/// Helper method to cast a 1-D memref<10xf32> "base" into a
/// memref<vector<10xf32>> in the output parameter "newBase",
/// using the 'element' vector type "vt". Returns true on success.
static bool castedToMemRef(Location loc, Value base, MemRefType mt,
VectorType vt, PatternRewriter &rewriter,
Value &newBase) {
// The vector.type_cast operation does not accept unknown memref<?xf32>.
// TODO: generalize the cast and accept this case too
if (!mt.hasStaticShape())
return false;
newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
return true;
}
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
@ -2380,13 +2366,10 @@ public:
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
Value newBase;
switch (get1DMaskFormat(load.mask())) {
case MaskFormat::AllTrue:
if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
load.getResultVectorType(), rewriter, newBase))
return failure();
rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
load, load.getType(), load.base(), load.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(load, load.pass_thru());
@ -2426,13 +2409,10 @@ public:
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
Value newBase;
switch (get1DMaskFormat(store.mask())) {
case MaskFormat::AllTrue:
if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
store.getValueVectorType(), rewriter, newBase))
return failure();
rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
store, store.value(), store.base(), store.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(store);
@ -2568,14 +2548,10 @@ public:
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
Value newBase;
switch (get1DMaskFormat(expand.mask())) {
case MaskFormat::AllTrue:
if (!castedToMemRef(expand.getLoc(), expand.base(),
expand.getMemRefType(), expand.getResultVectorType(),
rewriter, newBase))
return failure();
rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
expand, expand.getType(), expand.base(), expand.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(expand, expand.pass_thru());
@ -2615,14 +2591,11 @@ public:
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
Value newBase;
switch (get1DMaskFormat(compress.mask())) {
case MaskFormat::AllTrue:
if (!castedToMemRef(compress.getLoc(), compress.base(),
compress.getMemRefType(),
compress.getValueVectorType(), rewriter, newBase))
return failure();
rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
compress, compress.value(), compress.base(), compress.indices(),
false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(compress);

View File

@ -1070,23 +1070,29 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
// CHECK: llvm.return %[[T]] : !llvm.vec<16 x f32>
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%c0 = constant 0: index
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @masked_load_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[B]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
// CHECK: llvm.return %[[L]] : !llvm.vec<16 x f32>
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
%c0 = constant 0: index
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: func @masked_store_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
// CHECK: llvm.intr.masked.store %{{.*}}, %[[B]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
// CHECK: llvm.return
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
@ -1110,21 +1116,25 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// CHECK: llvm.return
func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
%0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
%c0 = constant 0: index
%0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
return %0 : vector<11xf32>
}
// CHECK-LABEL: func @expand_load_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<f32>, !llvm.vec<11 x i1>, !llvm.vec<11 x f32>) -> !llvm.vec<11 x f32>
// CHECK: llvm.return %[[E]] : !llvm.vec<11 x f32>
func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
%c0 = constant 0: index
vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
return
}
// CHECK-LABEL: func @compress_store_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x f32>, !llvm.ptr<f32>, !llvm.vec<11 x i1>) -> ()
// CHECK: llvm.return

View File

@ -1199,36 +1199,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
// -----
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
}
// -----
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
// -----
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
}
// -----
@ -1297,36 +1302,41 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
// -----
func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
}
// -----
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
// -----
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}
// -----

View File

@ -452,10 +452,11 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
// CHECK-LABEL: @masked_load_and_store
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
%c0 = constant 0 : index
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.maskedstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@ -472,10 +473,11 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
// CHECK-LABEL: @expand_and_compress
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
%c0 = constant 0 : index
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}

View File

@ -1,82 +1,93 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
//
// TODO: optimize this one too!
//
// CHECK-LABEL: func @maskedload0(
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask
// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
// CHECK-LABEL: func @maskedload0(
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.maskedload %base, %mask, %pass_thru
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @maskedload1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
// CHECK-NEXT: return %[[T1]] : vector<16xf32>
// CHECK-LABEL: func @maskedload1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.maskedload %base, %mask, %pass_thru
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @maskedload2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
// CHECK-LABEL: func @maskedload2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.maskedload %base, %mask, %pass_thru
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @maskedstore1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
// CHECK-NEXT: return
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-LABEL: func @maskedload3(
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 8 : index
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c8 = constant 8 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.maskedstore %base, %mask, %value
: vector<16xi1>, vector<16xf32> into memref<16xf32>
%ld = vector.maskedload %base[%c8], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @maskedstore1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
// CHECK-NEXT: return
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: func @maskedstore2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: return
// CHECK-LABEL: func @maskedstore2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
vector.maskedstore %base, %mask, %value
: vector<16xi1>, vector<16xf32> into memref<16xf32>
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: func @gather1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
// CHECK-NEXT: return %1 : vector<16xf32>
// CHECK-LABEL: func @gather1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.gather %base, %indices, %mask, %pass_thru
@ -84,12 +95,11 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @gather2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
// CHECK-LABEL: func @gather2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.gather %base, %indices, %mask, %pass_thru
@ -97,14 +107,13 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @scatter1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
// CHECK-NEXT: return
// CHECK-LABEL: func @scatter1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
// CHECK-NEXT: return
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%mask = vector.constant_mask [16] : vector<16xi1>
vector.scatter %base, %indices, %mask, %value
@ -112,12 +121,11 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
return
}
// CHECK-LABEL: func @scatter2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
// CHECK-NEXT: return
// CHECK-LABEL: func @scatter2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
@ -126,52 +134,53 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
return
}
// CHECK-LABEL: func @expand1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
// CHECK-NEXT: return %[[T1]] : vector<16xf32>
// CHECK-LABEL: func @expand1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.expandload %base, %mask, %pass_thru
%ld = vector.expandload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @expand2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
// CHECK-LABEL: func @expand2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.expandload %base, %mask, %pass_thru
%ld = vector.expandload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
// CHECK-LABEL: func @compress1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
// CHECK-NEXT: return
// CHECK-LABEL: func @compress1(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
// CHECK-NEXT: return
func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: func @compress2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
// CHECK-NEXT: return
// CHECK-LABEL: func @compress2(
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>

View File

@ -24,13 +24,25 @@ namespace {
struct TestVectorToVectorConversion
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
TestVectorToVectorConversion() = default;
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
llvm::cl::init(false)};
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
if (unroll) {
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
}
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));