[mlir][vector] Support unrolling for transfer ops using tensors

Differential Revision: https://reviews.llvm.org/D93904
This commit is contained in:
Thomas Raoux 2020-12-29 09:59:36 -08:00
parent 90bf3ecef4
commit f9190c8681
5 changed files with 166 additions and 29 deletions

View File

@ -71,7 +71,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
/// transfer_write.
/// transfer_write. If the transfer_write takes a tensor source, return the
/// unrolled Value in result.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@ -83,7 +84,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
ArrayRef<int64_t> targetShape,
SmallVector<Value, 1> &result);
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
@ -143,9 +145,10 @@ struct UnrollVectorPattern : public RewritePattern {
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
if (isa<TransferWriteOp>(op)) {
if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
SmallVector<Value, 1> result;
if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result)))
return failure();
rewriter.eraseOp(op);
rewriter.replaceOp(op, result);
return success();
}
if (op->getNumResults() != 1)

View File

@ -515,7 +515,7 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
static void generateTransferOpSlices(
Type memrefElementType, VectorType vectorType, TupleType tupleType,
Type shapedElementType, VectorType vectorType, TupleType tupleType,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
@ -539,9 +539,9 @@ static void generateTransferOpSlices(
// vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
//
unsigned vectorRank = vectorType.getRank();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
assert(vectorRank >= memrefVectorElementType.getRank());
vectorRank -= memrefVectorElementType.getRank();
if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
assert(vectorRank >= sourceVectorElementType.getRank());
vectorRank -= sourceVectorElementType.getRank();
}
unsigned indexOffset = numSliceIndices - vectorRank;
@ -598,8 +598,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
Location loc = readOp.getLoc();
auto memrefElementType =
readOp.source().getType().cast<MemRefType>().getElementType();
auto shapedElementType =
readOp.source().getType().cast<ShapedType>().getElementType();
auto tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
int64_t numSlices = tupleType.size();
@ -618,7 +618,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
readOp.permutation_map(), readOp.padding(),
readOp.masked() ? *readOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
// Create tuple of splice transfer read operations.
@ -634,7 +634,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
// Entry point for unrolling declarative pattern rewrite for transfer_write op.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape) {
ArrayRef<int64_t> targetShape,
SmallVector<Value, 1> &result) {
auto writeOp = cast<vector::TransferWriteOp>(op);
if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
@ -645,20 +646,28 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
Location loc = writeOp.getLoc();
Value tuple = builder.create<vector::ExtractSlicesOp>(
loc, tupleType, writeOp.vector(), targetShape, strides);
auto memrefElementType =
writeOp.source().getType().cast<MemRefType>().getElementType();
auto shapedElementType =
writeOp.source().getType().cast<ShapedType>().getElementType();
SmallVector<Value, 4> indices(writeOp.indices().begin(),
writeOp.indices().end());
// If the TransferWrite returns a tensor, keep track of the last tensor
// created.
Value resultTensor;
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
auto element = builder.create<vector::TupleGetOp>(
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
builder.create<vector::TransferWriteOp>(
loc, element.getResult(), writeOp.source(), sliceIndices,
Operation *write = builder.create<vector::TransferWriteOp>(
loc, element.getResult(),
resultTensor ? resultTensor : writeOp.source(), sliceIndices,
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
if (!write->getResults().empty())
resultTensor = write->getResult(0);
};
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
if (resultTensor)
result.push_back(resultTensor);
return success();
}
@ -761,25 +770,32 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
insertSlicesOp.getStrides(strides);
Location loc = xferWriteOp.getLoc();
auto memrefElementType =
xferWriteOp.source().getType().cast<MemRefType>().getElementType();
auto shapedElementType =
xferWriteOp.source().getType().cast<ShapedType>().getElementType();
SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
xferWriteOp.indices().end());
Value resultTensor;
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
Operation *write = rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index),
resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
if (!write->getResults().empty())
resultTensor = write->getResult(0);
};
generateTransferOpSlices(memrefElementType, resultVectorType,
generateTransferOpSlices(shapedElementType, resultVectorType,
sourceTupleType, sizes, strides, indices, rewriter,
createSlice);
// Erase old 'xferWriteOp'.
rewriter.eraseOp(xferWriteOp);
if (resultTensor)
rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
else
rewriter.eraseOp(xferWriteOp);
return success();
}
};

View File

@ -58,3 +58,65 @@ func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
// CHECK-LABEL: func @transfer_read_unroll_tensor
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>
func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
return %0 : vector<4x4xf32>
}
// CHECK-LABEL: func @transfer_write_unroll_tensor
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
%arg1 : vector<4x4xf32>) -> tensor<4x4xf32> {
%c0 = constant 0 : index
%r = vector.transfer_write %arg1, %arg0[%c0, %c0] :
vector<4x4xf32>, tensor<4x4xf32>
return %r: tensor<4x4xf32>
}
// CHECK-LABEL: func @transfer_readwrite_unroll_tensor
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[VTR1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[VTR2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) ->
tensor<4x4xf32> {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
%r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
return %r: tensor<4x4xf32>
}

View File

@ -530,6 +530,14 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
// CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
// CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
// CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
@ -544,7 +552,52 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%cond = cmpf "ult", %0, %1 : vector<4x4xf32>
%2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
// Vector transfer split pattern only support single user right now.
%2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
// Check that vector.transfer read/write are split based on contract unrolling.
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
%arg1 : tensor<2x4xf32>,
%arg2 : tensor<4x4xf32>) ->
tensor<4x4xf32> {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
tensor<4x2xf32>, vector<4x2xf32>
%1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
tensor<2x4xf32>, vector<2x4xf32>
%2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
tensor<4x4xf32>, vector<4x4xf32>
%3 = vector.contract #contraction_trait1 %0, %1, %2
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
%r = vector.transfer_write %3, %arg2[%c0, %c0]
: vector<4x4xf32>, tensor<4x4xf32>
return %r : tensor<4x4xf32>
}

View File

@ -28,7 +28,9 @@ struct TestVectorToVectorConversion
OwningRewritePatternList patterns;
auto *ctx = &getContext();
patterns.insert<UnrollVectorPattern>(
ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
@ -39,13 +41,14 @@ private:
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
if (isa<AddFOp, SelectOp, CmpFOp>(op))
return SmallVector<int64_t, 4>(2, 2);
if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
}
if (isa<vector::ContractionOp>(op))
return SmallVector<int64_t, 4>(3, 2);
return llvm::None;
}
static LogicalResult filter(Operation *op) {
return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
}
};
struct TestVectorSlicesConversion