forked from OSchip/llvm-project
[mlir][vector] generalized masked l/s and compressed l/s with indices
Adding the ability to index the base address brings these operations closer to the transfer read and write semantics (with lowering advantages), ensures more consistent use in vector MLIR code (easier to read), and reduces the amount of code duplication to lower memrefs into base addresses considerably (making codegen less error-prone). Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D94278
This commit is contained in:
parent
1ba5ea67a3
commit
a57def30f5
|
@ -1317,6 +1317,7 @@ def Vector_TransferWriteOp :
|
|||
def Vector_MaskedLoadOp :
|
||||
Vector_Op<"maskedload">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
Variadic<Index>:$indices,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$pass_thru)>,
|
||||
Results<(outs VectorOfRank<[1]>:$result)> {
|
||||
|
@ -1325,12 +1326,12 @@ def Vector_MaskedLoadOp :
|
|||
|
||||
let description = [{
|
||||
The masked load reads elements from memory into a 1-D vector as defined
|
||||
by a base and a 1-D mask vector. When the mask is set, the element is read
|
||||
from memory. Otherwise, the corresponding element is taken from a 1-D
|
||||
pass-through vector. Informally the semantics are:
|
||||
by a base with indices and a 1-D mask vector. When the mask is set, the
|
||||
element is read from memory. Otherwise, the corresponding element is taken
|
||||
from a 1-D pass-through vector. Informally the semantics are:
|
||||
```
|
||||
result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
|
||||
result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
|
||||
result[0] := mask[0] ? base[i+0] : pass_thru[0]
|
||||
result[1] := mask[1] ? base[i+1] : pass_thru[1]
|
||||
etc.
|
||||
```
|
||||
The masked load can be used directly where applicable, or can be used
|
||||
|
@ -1342,7 +1343,7 @@ def Vector_MaskedLoadOp :
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = vector.maskedload %base, %mask, %pass_thru
|
||||
%0 = vector.maskedload %base[%i], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
|
||||
```
|
||||
}];
|
||||
|
@ -1360,7 +1361,7 @@ def Vector_MaskedLoadOp :
|
|||
return result().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
|
||||
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -1368,6 +1369,7 @@ def Vector_MaskedLoadOp :
|
|||
def Vector_MaskedStoreOp :
|
||||
Vector_Op<"maskedstore">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
Variadic<Index>:$indices,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$value)> {
|
||||
|
||||
|
@ -1375,12 +1377,12 @@ def Vector_MaskedStoreOp :
|
|||
|
||||
let description = [{
|
||||
The masked store operation writes elements from a 1-D vector into memory
|
||||
as defined by a base and a 1-D mask vector. When the mask is set, the
|
||||
corresponding element from the vector is written to memory. Otherwise,
|
||||
as defined by a base with indices and a 1-D mask vector. When the mask is
|
||||
set, the corresponding element from the vector is written to memory. Otherwise,
|
||||
no action is taken for the element. Informally the semantics are:
|
||||
```
|
||||
if (mask[0]) MEM[base+0] = value[0]
|
||||
if (mask[1]) MEM[base+1] = value[1]
|
||||
if (mask[0]) base[i+0] = value[0]
|
||||
if (mask[1]) base[i+1] = value[1]
|
||||
etc.
|
||||
```
|
||||
The masked store can be used directly where applicable, or can be used
|
||||
|
@ -1392,7 +1394,7 @@ def Vector_MaskedStoreOp :
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
vector.maskedstore %base, %mask, %value
|
||||
vector.maskedstore %base[%i], %mask, %value
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32>
|
||||
```
|
||||
}];
|
||||
|
@ -1407,8 +1409,8 @@ def Vector_MaskedStoreOp :
|
|||
return value().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
|
||||
"type($mask) `,` type($value) `into` type($base)";
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
|
||||
"type($base) `,` type($mask) `,` type($value)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -1430,8 +1432,8 @@ def Vector_GatherOp :
|
|||
semantics are:
|
||||
```
|
||||
if (!defined(pass_thru)) pass_thru = [undef, .., undef]
|
||||
result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
|
||||
result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
|
||||
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
|
||||
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
|
||||
etc.
|
||||
```
|
||||
The vector dialect leaves out-of-bounds behavior undefined.
|
||||
|
@ -1487,8 +1489,8 @@ def Vector_ScatterOp :
|
|||
bit in a 1-D mask vector is set. Otherwise, no action is taken for that
|
||||
element. Informally the semantics are:
|
||||
```
|
||||
if (mask[0]) MEM[base + index[0]] = value[0]
|
||||
if (mask[1]) MEM[base + index[1]] = value[1]
|
||||
if (mask[0]) base[index[0]] = value[0]
|
||||
if (mask[1]) base[index[1]] = value[1]
|
||||
etc.
|
||||
```
|
||||
The vector dialect leaves out-of-bounds and repeated index behavior
|
||||
|
@ -1531,6 +1533,7 @@ def Vector_ScatterOp :
|
|||
def Vector_ExpandLoadOp :
|
||||
Vector_Op<"expandload">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
Variadic<Index>:$indices,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$pass_thru)>,
|
||||
Results<(outs VectorOfRank<[1]>:$result)> {
|
||||
|
@ -1539,13 +1542,13 @@ def Vector_ExpandLoadOp :
|
|||
|
||||
let description = [{
|
||||
The expand load reads elements from memory into a 1-D vector as defined
|
||||
by a base and a 1-D mask vector. When the mask is set, the next element
|
||||
is read from memory. Otherwise, the corresponding element is taken from
|
||||
a 1-D pass-through vector. Informally the semantics are:
|
||||
by a base with indices and a 1-D mask vector. When the mask is set, the
|
||||
next element is read from memory. Otherwise, the corresponding element
|
||||
is taken from a 1-D pass-through vector. Informally the semantics are:
|
||||
```
|
||||
index = base
|
||||
result[0] := mask[0] ? MEM[index++] : pass_thru[0]
|
||||
result[1] := mask[1] ? MEM[index++] : pass_thru[1]
|
||||
index = i
|
||||
result[0] := mask[0] ? base[index++] : pass_thru[0]
|
||||
result[1] := mask[1] ? base[index++] : pass_thru[1]
|
||||
etc.
|
||||
```
|
||||
Note that the index increment is done conditionally.
|
||||
|
@ -1559,7 +1562,7 @@ def Vector_ExpandLoadOp :
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = vector.expandload %base, %mask, %pass_thru
|
||||
%0 = vector.expandload %base[%i], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
|
||||
```
|
||||
}];
|
||||
|
@ -1577,7 +1580,7 @@ def Vector_ExpandLoadOp :
|
|||
return result().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
|
||||
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -1585,6 +1588,7 @@ def Vector_ExpandLoadOp :
|
|||
def Vector_CompressStoreOp :
|
||||
Vector_Op<"compressstore">,
|
||||
Arguments<(ins AnyMemRef:$base,
|
||||
Variadic<Index>:$indices,
|
||||
VectorOfRankAndType<[1], [I1]>:$mask,
|
||||
VectorOfRank<[1]>:$value)> {
|
||||
|
||||
|
@ -1592,13 +1596,13 @@ def Vector_CompressStoreOp :
|
|||
|
||||
let description = [{
|
||||
The compress store operation writes elements from a 1-D vector into memory
|
||||
as defined by a base and a 1-D mask vector. When the mask is set, the
|
||||
corresponding element from the vector is written next to memory. Otherwise,
|
||||
no action is taken for the element. Informally the semantics are:
|
||||
as defined by a base with indices and a 1-D mask vector. When the mask is
|
||||
set, the corresponding element from the vector is written next to memory.
|
||||
Otherwise, no action is taken for the element. Informally the semantics are:
|
||||
```
|
||||
index = base
|
||||
if (mask[0]) MEM[index++] = value[0]
|
||||
if (mask[1]) MEM[index++] = value[1]
|
||||
index = i
|
||||
if (mask[0]) base[index++] = value[0]
|
||||
if (mask[1]) base[index++] = value[1]
|
||||
etc.
|
||||
```
|
||||
Note that the index increment is done conditionally.
|
||||
|
@ -1612,7 +1616,7 @@ def Vector_CompressStoreOp :
|
|||
Example:
|
||||
|
||||
```mlir
|
||||
vector.compressstore %base, %mask, %value
|
||||
vector.compressstore %base[%i], %mask, %value
|
||||
: memref<?xf32>, vector<8xi1>, vector<8xf32>
|
||||
```
|
||||
}];
|
||||
|
@ -1627,7 +1631,7 @@ def Vector_CompressStoreOp :
|
|||
return value().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
|
||||
"type($base) `,` type($mask) `,` type($value)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,16 @@
|
|||
|
||||
func @compress16(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
vector.compressstore %base, %mask, %value
|
||||
%c0 = constant 0: index
|
||||
vector.compressstore %base[%c0], %mask, %value
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @compress16_at8(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
%c8 = constant 8: index
|
||||
vector.compressstore %base[%c8], %mask, %value
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
@ -86,5 +95,10 @@ func @entry() {
|
|||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
|
||||
|
||||
call @compress16_at8(%A, %some1, %value)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 0, 1, 2, 3, 12, 13, 14, 15 )
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,8 +5,18 @@
|
|||
|
||||
func @expand16(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>,
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%e = vector.expandload %base, %mask, %pass_thru
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0: index
|
||||
%e = vector.expandload %base[%c0], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %e : vector<16xf32>
|
||||
}
|
||||
|
||||
func @expand16_at8(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>,
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c8 = constant 8: index
|
||||
%e = vector.expandload %base[%c8], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %e : vector<16xf32>
|
||||
}
|
||||
|
@ -78,5 +88,10 @@ func @entry() {
|
|||
vector.print %e6 : vector<16xf32>
|
||||
// CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
|
||||
|
||||
%e7 = call @expand16_at8(%A, %some1, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %e7 : vector<16xf32>
|
||||
// CHECK-NEXT: ( 8, 9, 10, 11, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,7 +5,16 @@
|
|||
|
||||
func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%ld = vector.maskedload %base, %mask, %pass_thru
|
||||
%c0 = constant 0: index
|
||||
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
func @maskedload16_at8(%base: memref<?xf32>, %mask: vector<16xi1>,
|
||||
%pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c8 = constant 8: index
|
||||
%ld = vector.maskedload %base[%c8], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
@ -61,6 +70,11 @@ func @entry() {
|
|||
vector.print %l4 : vector<16xf32>
|
||||
// CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
|
||||
|
||||
%l5 = call @maskedload16_at8(%A, %some, %pass)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
|
||||
vector.print %l5 : vector<16xf32>
|
||||
// CHECK: ( 8, 9, 10, 11, 12, 13, 14, 15, -7, -7, -7, -7, -7, -7, -7, -7 )
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,17 @@
|
|||
|
||||
func @maskedstore16(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
vector.maskedstore %base, %mask, %value
|
||||
: vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
%c0 = constant 0: index
|
||||
vector.maskedstore %base[%c0], %mask, %value
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @maskedstore16_at8(%base: memref<?xf32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
%c8 = constant 8: index
|
||||
vector.maskedstore %base[%c8], %mask, %value
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -85,5 +94,10 @@ func @entry() {
|
|||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
|
||||
|
||||
call @maskedstore16_at8(%A, %some, %val)
|
||||
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
|
||||
call @printmem16(%A) : (memref<?xf32>) -> ()
|
||||
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 )
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -173,33 +173,7 @@ static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns a pointer given a memref base.
|
||||
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value memref,
|
||||
MemRefType memRefType, Value &ptr) {
|
||||
Value base;
|
||||
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
|
||||
return failure();
|
||||
auto pType = MemRefDescriptor(memref).getElementPtrType();
|
||||
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns a bit-casted pointer given a memref base.
|
||||
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value memref,
|
||||
MemRefType memRefType, Type type, Value &ptr) {
|
||||
Value base;
|
||||
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
|
||||
return failure();
|
||||
auto pType = LLVM::LLVMPointerType::get(type);
|
||||
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
|
||||
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns vector of pointers given a memref base and an index
|
||||
// vector.
|
||||
// Helper that returns vector of pointers given a memref base with index vector.
|
||||
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value memref, Value indices,
|
||||
MemRefType memRefType, VectorType vType,
|
||||
|
@ -213,6 +187,18 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Casts a strided element pointer to a vector pointer. The vector pointer
|
||||
// would always be on address space 0, therefore addrspacecast shall be
|
||||
// used when source/dst memrefs are not on address space 0.
|
||||
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value ptr, MemRefType memRefType, Type vt) {
|
||||
auto pType =
|
||||
LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
|
||||
if (memRefType.getMemorySpace() == 0)
|
||||
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
|
||||
return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
|
@ -343,18 +329,18 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = load->getLoc();
|
||||
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
|
||||
MemRefType memRefType = load.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
|
||||
align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
auto vtype = typeConverter->convertType(load.getResultVectorType());
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
|
||||
vtype, ptr)))
|
||||
return failure();
|
||||
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
|
||||
|
@ -374,18 +360,18 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = store->getLoc();
|
||||
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
|
||||
MemRefType memRefType = store.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
|
||||
align)))
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
auto vtype = typeConverter->convertType(store.getValueVectorType());
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
|
||||
vtype, ptr)))
|
||||
return failure();
|
||||
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
store, adaptor.value(), ptr, adaptor.mask(),
|
||||
|
@ -473,16 +459,15 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = expand->getLoc();
|
||||
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
|
||||
MemRefType memRefType = expand.getMemRefType();
|
||||
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
|
||||
ptr)))
|
||||
return failure();
|
||||
// Resolve address.
|
||||
auto vtype = typeConverter->convertType(expand.getResultVectorType());
|
||||
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
|
||||
auto vType = expand.getResultVectorType();
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
|
||||
expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
|
||||
adaptor.pass_thru());
|
||||
expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -498,11 +483,11 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = compress->getLoc();
|
||||
auto adaptor = vector::CompressStoreOpAdaptor(operands);
|
||||
MemRefType memRefType = compress.getMemRefType();
|
||||
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(),
|
||||
compress.getMemRefType(), ptr)))
|
||||
return failure();
|
||||
// Resolve address.
|
||||
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
|
||||
compress, adaptor.value(), ptr, adaptor.mask());
|
||||
|
@ -1223,21 +1208,11 @@ public:
|
|||
}
|
||||
|
||||
// 1. Get the source/dst address as an LLVM vector pointer.
|
||||
// The vector pointer would always be on address space 0, therefore
|
||||
// addrspacecast shall be used when source/dst memrefs are not on
|
||||
// address space 0.
|
||||
// TODO: support alignment when possible.
|
||||
VectorType vtp = xferOp.getVectorType();
|
||||
Value dataPtr = this->getStridedElementPtr(
|
||||
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
|
||||
auto vecTy = toLLVMTy(xferOp.getVectorType())
|
||||
.template cast<LLVM::LLVMFixedVectorType>();
|
||||
Value vectorDataPtr;
|
||||
if (memRefType.getMemorySpace() == 0)
|
||||
vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
|
||||
else
|
||||
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
|
||||
loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
|
||||
Value vectorDataPtr =
|
||||
castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
|
||||
|
||||
if (!xferOp.isMaskedDim(0))
|
||||
return replaceTransferOpWithLoadOrStore(rewriter,
|
||||
|
@ -1251,7 +1226,7 @@ public:
|
|||
//
|
||||
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
|
||||
// dimensions here.
|
||||
unsigned vecWidth = vecTy.getNumElements();
|
||||
unsigned vecWidth = vtp.getNumElements();
|
||||
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
|
||||
Value off = xferOp.indices()[lastIndex];
|
||||
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
|
||||
|
|
|
@ -76,20 +76,6 @@ static MaskFormat get1DMaskFormat(Value mask) {
|
|||
return MaskFormat::Unknown;
|
||||
}
|
||||
|
||||
/// Helper method to cast a 1-D memref<10xf32> "base" into a
|
||||
/// memref<vector<10xf32>> in the output parameter "newBase",
|
||||
/// using the 'element' vector type "vt". Returns true on success.
|
||||
static bool castedToMemRef(Location loc, Value base, MemRefType mt,
|
||||
VectorType vt, PatternRewriter &rewriter,
|
||||
Value &newBase) {
|
||||
// The vector.type_cast operation does not accept unknown memref<?xf32>.
|
||||
// TODO: generalize the cast and accept this case too
|
||||
if (!mt.hasStaticShape())
|
||||
return false;
|
||||
newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2380,13 +2366,10 @@ public:
|
|||
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(MaskedLoadOp load,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value newBase;
|
||||
switch (get1DMaskFormat(load.mask())) {
|
||||
case MaskFormat::AllTrue:
|
||||
if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
|
||||
load.getResultVectorType(), rewriter, newBase))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
|
||||
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
|
||||
load, load.getType(), load.base(), load.indices(), false);
|
||||
return success();
|
||||
case MaskFormat::AllFalse:
|
||||
rewriter.replaceOp(load, load.pass_thru());
|
||||
|
@ -2426,13 +2409,10 @@ public:
|
|||
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(MaskedStoreOp store,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value newBase;
|
||||
switch (get1DMaskFormat(store.mask())) {
|
||||
case MaskFormat::AllTrue:
|
||||
if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
|
||||
store.getValueVectorType(), rewriter, newBase))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||
store, store.value(), store.base(), store.indices(), false);
|
||||
return success();
|
||||
case MaskFormat::AllFalse:
|
||||
rewriter.eraseOp(store);
|
||||
|
@ -2568,14 +2548,10 @@ public:
|
|||
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ExpandLoadOp expand,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value newBase;
|
||||
switch (get1DMaskFormat(expand.mask())) {
|
||||
case MaskFormat::AllTrue:
|
||||
if (!castedToMemRef(expand.getLoc(), expand.base(),
|
||||
expand.getMemRefType(), expand.getResultVectorType(),
|
||||
rewriter, newBase))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
|
||||
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
|
||||
expand, expand.getType(), expand.base(), expand.indices(), false);
|
||||
return success();
|
||||
case MaskFormat::AllFalse:
|
||||
rewriter.replaceOp(expand, expand.pass_thru());
|
||||
|
@ -2615,14 +2591,11 @@ public:
|
|||
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(CompressStoreOp compress,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value newBase;
|
||||
switch (get1DMaskFormat(compress.mask())) {
|
||||
case MaskFormat::AllTrue:
|
||||
if (!castedToMemRef(compress.getLoc(), compress.base(),
|
||||
compress.getMemRefType(),
|
||||
compress.getValueVectorType(), rewriter, newBase))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||
compress, compress.value(), compress.base(), compress.indices(),
|
||||
false);
|
||||
return success();
|
||||
case MaskFormat::AllFalse:
|
||||
rewriter.eraseOp(compress);
|
||||
|
|
|
@ -1070,23 +1070,29 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
|
|||
// CHECK: llvm.return %[[T]] : !llvm.vec<16 x f32>
|
||||
|
||||
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
|
||||
%0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%c0 = constant 0: index
|
||||
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @masked_load_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[B]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
|
||||
// CHECK: llvm.return %[[L]] : !llvm.vec<16 x f32>
|
||||
|
||||
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
|
||||
vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
%c0 = constant 0: index
|
||||
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @masked_store_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: llvm.intr.masked.store %{{.*}}, %[[B]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
|
||||
// CHECK: llvm.return
|
||||
|
||||
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
|
||||
|
@ -1110,21 +1116,25 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
|
|||
// CHECK: llvm.return
|
||||
|
||||
func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
|
||||
%0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
|
||||
%c0 = constant 0: index
|
||||
%0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
|
||||
return %0 : vector<11xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expand_load_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<f32>, !llvm.vec<11 x i1>, !llvm.vec<11 x f32>) -> !llvm.vec<11 x f32>
|
||||
// CHECK: llvm.return %[[E]] : !llvm.vec<11 x f32>
|
||||
|
||||
func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
|
||||
vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
|
||||
%c0 = constant 0: index
|
||||
vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compress_store_op
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x f32>, !llvm.ptr<f32>, !llvm.vec<11 x i1>) -> ()
|
||||
// CHECK: llvm.return
|
||||
|
|
|
@ -1199,36 +1199,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
|
|||
// -----
|
||||
|
||||
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
|
||||
%0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
|
||||
vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
|
||||
vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
|
||||
vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
|
||||
vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -1297,36 +1302,41 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
|
|||
// -----
|
||||
|
||||
func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
|
||||
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
|
||||
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
|
||||
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
|
||||
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
|
||||
vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
|
||||
vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
|
||||
vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
|
||||
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -452,10 +452,11 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
|
|||
|
||||
// CHECK-LABEL: @masked_load_and_store
|
||||
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
|
||||
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
vector.maskedstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -472,10 +473,11 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
|
|||
|
||||
// CHECK-LABEL: @expand_and_compress
|
||||
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
|
||||
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
%0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,82 +1,93 @@
|
|||
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
||||
|
||||
//
|
||||
// TODO: optimize this one too!
|
||||
//
|
||||
// CHECK-LABEL: func @maskedload0(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask
|
||||
// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
// CHECK-NEXT: return %[[T]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @maskedload0(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-DAG: %[[C:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
|
||||
// CHECK-NEXT: return %[[T]] : vector<16xf32>
|
||||
func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
%ld = vector.maskedload %base, %mask, %pass_thru
|
||||
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maskedload1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
|
||||
// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
|
||||
// CHECK-NEXT: return %[[T1]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @maskedload1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-DAG: %[[C:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
|
||||
// CHECK-NEXT: return %[[T]] : vector<16xf32>
|
||||
func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
%ld = vector.maskedload %base, %mask, %pass_thru
|
||||
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
|
||||
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maskedload2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @maskedload2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
|
||||
func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
%ld = vector.maskedload %base, %mask, %pass_thru
|
||||
%ld = vector.maskedload %base[%c0], %mask, %pass_thru
|
||||
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maskedstore1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
|
||||
// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
|
||||
// CHECK-NEXT: return
|
||||
|
||||
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
|
||||
// CHECK-LABEL: func @maskedload3(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-DAG: %[[C:.*]] = constant 8 : index
|
||||
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
|
||||
// CHECK-NEXT: return %[[T]] : vector<16xf32>
|
||||
func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c8 = constant 8 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
vector.maskedstore %base, %mask, %value
|
||||
: vector<16xi1>, vector<16xf32> into memref<16xf32>
|
||||
%ld = vector.maskedload %base[%c8], %mask, %pass_thru
|
||||
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maskedstore1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
|
||||
// CHECK-NEXT: return
|
||||
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maskedstore2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @maskedstore2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: return
|
||||
func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
vector.maskedstore %base, %mask, %value
|
||||
: vector<16xi1>, vector<16xf32> into memref<16xf32>
|
||||
vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @gather1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
|
||||
// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
|
||||
// CHECK-NEXT: return %1 : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @gather1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
|
||||
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
|
||||
// CHECK-NEXT: return %[[G]] : vector<16xf32>
|
||||
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
%ld = vector.gather %base, %indices, %mask, %pass_thru
|
||||
|
@ -84,12 +95,11 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
|
|||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @gather2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @gather2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
|
||||
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
%ld = vector.gather %base, %indices, %mask, %pass_thru
|
||||
|
@ -97,14 +107,13 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
|
|||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @scatter1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
|
||||
// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @scatter1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
|
||||
// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
|
||||
// CHECK-NEXT: return
|
||||
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
vector.scatter %base, %indices, %mask, %value
|
||||
|
@ -112,12 +121,11 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @scatter2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @scatter2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
|
||||
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: return
|
||||
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
|
||||
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
|
@ -126,52 +134,53 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expand1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
|
||||
// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
|
||||
// CHECK-NEXT: return %[[T1]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @expand1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-DAG: %[[C:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
|
||||
// CHECK-NEXT: return %[[T]] : vector<16xf32>
|
||||
func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
%ld = vector.expandload %base, %mask, %pass_thru
|
||||
%ld = vector.expandload %base[%c0], %mask, %pass_thru
|
||||
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @expand2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
|
||||
|
||||
// CHECK-LABEL: func @expand2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK-NEXT: return %[[A1]] : vector<16xf32>
|
||||
func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
%ld = vector.expandload %base, %mask, %pass_thru
|
||||
%ld = vector.expandload %base[%c0], %mask, %pass_thru
|
||||
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return %ld : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compress1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
|
||||
// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @compress1(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
|
||||
// CHECK-NEXT: return
|
||||
func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [16] : vector<16xi1>
|
||||
vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compress2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @compress2(
|
||||
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
|
||||
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
|
||||
// CHECK-NEXT: return
|
||||
func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%mask = vector.constant_mask [0] : vector<16xi1>
|
||||
vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
|
||||
|
|
|
@ -24,13 +24,25 @@ namespace {
|
|||
|
||||
struct TestVectorToVectorConversion
|
||||
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
|
||||
TestVectorToVectorConversion() = default;
|
||||
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect>();
|
||||
}
|
||||
|
||||
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
auto *ctx = &getContext();
|
||||
patterns.insert<UnrollVectorPattern>(
|
||||
ctx,
|
||||
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
|
||||
filter));
|
||||
if (unroll) {
|
||||
patterns.insert<UnrollVectorPattern>(
|
||||
ctx,
|
||||
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
|
||||
filter));
|
||||
}
|
||||
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
|
|
Loading…
Reference in New Issue