forked from OSchip/llvm-project
[mlir][vector] Support unrolling for transfer ops using tensors
Differential Revision: https://reviews.llvm.org/D93904
This commit is contained in:
parent
90bf3ecef4
commit
f9190c8681
|
@ -71,7 +71,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
|
|||
|
||||
/// Unroll a transfer_write op. Break up the vector source into a tuple of
|
||||
/// vectors matching the given shape. Then store each element with its own
|
||||
/// transfer_write.
|
||||
/// transfer_write. If the transfer_write takes a tensor source, return the
|
||||
/// unrolled Value in result.
|
||||
///
|
||||
/// Example:
|
||||
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
|
||||
|
@ -83,7 +84,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
|
|||
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
|
||||
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
|
||||
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<int64_t> targetShape);
|
||||
ArrayRef<int64_t> targetShape,
|
||||
SmallVector<Value, 1> &result);
|
||||
|
||||
/// Options that control the vector unrolling.
|
||||
struct UnrollVectorOptions {
|
||||
|
@ -143,9 +145,10 @@ struct UnrollVectorPattern : public RewritePattern {
|
|||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
|
||||
return failure();
|
||||
if (isa<TransferWriteOp>(op)) {
|
||||
if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
|
||||
SmallVector<Value, 1> result;
|
||||
if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result)))
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
if (op->getNumResults() != 1)
|
||||
|
|
|
@ -515,7 +515,7 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
|
|||
/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
|
||||
/// calls 'fn' with linear index and indices for each slice.
|
||||
static void generateTransferOpSlices(
|
||||
Type memrefElementType, VectorType vectorType, TupleType tupleType,
|
||||
Type shapedElementType, VectorType vectorType, TupleType tupleType,
|
||||
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
|
||||
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
|
||||
// Compute strides w.r.t. slice counts in each dimension.
|
||||
|
@ -539,9 +539,9 @@ static void generateTransferOpSlices(
|
|||
// vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
|
||||
//
|
||||
unsigned vectorRank = vectorType.getRank();
|
||||
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
|
||||
assert(vectorRank >= memrefVectorElementType.getRank());
|
||||
vectorRank -= memrefVectorElementType.getRank();
|
||||
if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
|
||||
assert(vectorRank >= sourceVectorElementType.getRank());
|
||||
vectorRank -= sourceVectorElementType.getRank();
|
||||
}
|
||||
unsigned indexOffset = numSliceIndices - vectorRank;
|
||||
|
||||
|
@ -598,8 +598,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
|
|||
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
|
||||
|
||||
Location loc = readOp.getLoc();
|
||||
auto memrefElementType =
|
||||
readOp.source().getType().cast<MemRefType>().getElementType();
|
||||
auto shapedElementType =
|
||||
readOp.source().getType().cast<ShapedType>().getElementType();
|
||||
auto tupleType = generateExtractSlicesOpResultType(
|
||||
sourceVectorType, targetShape, strides, builder);
|
||||
int64_t numSlices = tupleType.size();
|
||||
|
@ -618,7 +618,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
|
|||
readOp.permutation_map(), readOp.padding(),
|
||||
readOp.masked() ? *readOp.masked() : ArrayAttr());
|
||||
};
|
||||
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
|
||||
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
|
||||
targetShape, strides, indices, builder, createSlice);
|
||||
|
||||
// Create tuple of split transfer read operations.
|
||||
|
@ -634,7 +634,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
|
|||
// Entry point for unrolling declarative pattern rewrite for transfer_write op.
|
||||
LogicalResult
|
||||
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
ArrayRef<int64_t> targetShape,
|
||||
SmallVector<Value, 1> &result) {
|
||||
auto writeOp = cast<vector::TransferWriteOp>(op);
|
||||
if (!isIdentitySuffix(writeOp.permutation_map()))
|
||||
return failure();
|
||||
|
@ -645,20 +646,28 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
|
|||
Location loc = writeOp.getLoc();
|
||||
Value tuple = builder.create<vector::ExtractSlicesOp>(
|
||||
loc, tupleType, writeOp.vector(), targetShape, strides);
|
||||
auto memrefElementType =
|
||||
writeOp.source().getType().cast<MemRefType>().getElementType();
|
||||
auto shapedElementType =
|
||||
writeOp.source().getType().cast<ShapedType>().getElementType();
|
||||
SmallVector<Value, 4> indices(writeOp.indices().begin(),
|
||||
writeOp.indices().end());
|
||||
// If the TransferWrite returns a tensor, keep track of the last tensor
|
||||
// created.
|
||||
Value resultTensor;
|
||||
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
|
||||
auto element = builder.create<vector::TupleGetOp>(
|
||||
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
|
||||
builder.create<vector::TransferWriteOp>(
|
||||
loc, element.getResult(), writeOp.source(), sliceIndices,
|
||||
Operation *write = builder.create<vector::TransferWriteOp>(
|
||||
loc, element.getResult(),
|
||||
resultTensor ? resultTensor : writeOp.source(), sliceIndices,
|
||||
writeOp.permutation_map(),
|
||||
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
|
||||
if (!write->getResults().empty())
|
||||
resultTensor = write->getResult(0);
|
||||
};
|
||||
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
|
||||
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
|
||||
targetShape, strides, indices, builder, createSlice);
|
||||
if (resultTensor)
|
||||
result.push_back(resultTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -761,25 +770,32 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
insertSlicesOp.getStrides(strides);
|
||||
|
||||
Location loc = xferWriteOp.getLoc();
|
||||
auto memrefElementType =
|
||||
xferWriteOp.source().getType().cast<MemRefType>().getElementType();
|
||||
auto shapedElementType =
|
||||
xferWriteOp.source().getType().cast<ShapedType>().getElementType();
|
||||
SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
|
||||
xferWriteOp.indices().end());
|
||||
Value resultTensor;
|
||||
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
|
||||
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
|
||||
// `masked` attribute propagates conservatively: if the coarse op didn't
|
||||
// need masking, the fine op doesn't either.
|
||||
rewriter.create<vector::TransferWriteOp>(
|
||||
loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
|
||||
Operation *write = rewriter.create<vector::TransferWriteOp>(
|
||||
loc, tupleOp.getOperand(index),
|
||||
resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
|
||||
xferWriteOp.permutation_map(),
|
||||
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
|
||||
if (!write->getResults().empty())
|
||||
resultTensor = write->getResult(0);
|
||||
};
|
||||
generateTransferOpSlices(memrefElementType, resultVectorType,
|
||||
generateTransferOpSlices(shapedElementType, resultVectorType,
|
||||
sourceTupleType, sizes, strides, indices, rewriter,
|
||||
createSlice);
|
||||
|
||||
// Erase old 'xferWriteOp'.
|
||||
rewriter.eraseOp(xferWriteOp);
|
||||
if (resultTensor)
|
||||
rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
|
||||
else
|
||||
rewriter.eraseOp(xferWriteOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -58,3 +58,65 @@ func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
|
|||
vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll_tensor
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
|
||||
// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>
|
||||
|
||||
// Test input: a single 4x4 vector.transfer_read from a tensor<4x4xf32> at
// [0, 0] with 0.0 padding. The CHECK lines preceding this function expect it
// to become four 2x2 reads recombined through vector.tuple and
// vector.insert_slices.
func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
|
||||
return %0 : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transfer_write_unroll_tensor
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
||||
// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
||||
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
||||
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
||||
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
||||
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
|
||||
|
||||
// Test input: writes the 4x4 vector %arg1 into tensor %arg0 at [0, 0] and
// returns the tensor produced by the write. The CHECK lines preceding this
// function expect four 2x2 transfer_writes, each one after the first taking
// the tensor result of the previous write as its destination.
func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
|
||||
%arg1 : vector<4x4xf32>) -> tensor<4x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%r = vector.transfer_write %arg1, %arg0[%c0, %c0] :
|
||||
vector<4x4xf32>, tensor<4x4xf32>
|
||||
return %r: tensor<4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transfer_readwrite_unroll_tensor
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[VTR1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[VTR2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
|
||||
|
||||
// Test input: reads a 4x4 vector from tensor %arg0 at [0, 0] and writes it
// straight back to the same position, returning the tensor produced by the
// write. The CHECK lines preceding this function expect four 2x2 reads feeding
// four chained 2x2 writes, with no tuple/insert_slices ops left in between.
func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) ->
|
||||
tensor<4x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
|
||||
%r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
|
||||
return %r: tensor<4x4xf32>
|
||||
}
|
||||
|
|
|
@ -530,6 +530,14 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
|
|||
// CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
|
||||
// CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
|
||||
// CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
|
||||
// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
|
||||
// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
|
||||
// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
|
||||
|
@ -544,7 +552,52 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
|
|||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
|
||||
%1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
|
||||
%cond = cmpf "ult", %0, %1 : vector<4x4xf32>
|
||||
%2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
|
||||
vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
|
||||
// The vector transfer split pattern only supports a single user right now.
|
||||
%2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
|
||||
%3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
|
||||
%4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
|
||||
vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// Check that vector.transfer read/write are split based on contract unrolling.
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
|
||||
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
|
||||
|
||||
// Test input: reads three tensors (4x2, 2x4, 4x4) as whole vectors, feeds them
// to a vector.contract using #contraction_trait1 (defined outside this
// excerpt), and writes the 4x4 result into %arg2 at [0, 0], returning the
// tensor produced by the write. The CHECK lines preceding this function expect
// the reads, the contraction and the write to all be split into 2x2 pieces.
func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
|
||||
%arg1 : tensor<2x4xf32>,
|
||||
%arg2 : tensor<4x4xf32>) ->
|
||||
tensor<4x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
|
||||
tensor<4x2xf32>, vector<4x2xf32>
|
||||
%1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
|
||||
tensor<2x4xf32>, vector<2x4xf32>
|
||||
%2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
|
||||
tensor<4x4xf32>, vector<4x4xf32>
|
||||
%3 = vector.contract #contraction_trait1 %0, %1, %2
|
||||
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
|
||||
%r = vector.transfer_write %3, %arg2[%c0, %c0]
|
||||
: vector<4x4xf32>, tensor<4x4xf32>
|
||||
return %r : tensor<4x4xf32>
|
||||
}
|
||||
|
|
|
@ -28,7 +28,9 @@ struct TestVectorToVectorConversion
|
|||
OwningRewritePatternList patterns;
|
||||
auto *ctx = &getContext();
|
||||
patterns.insert<UnrollVectorPattern>(
|
||||
ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
|
||||
ctx,
|
||||
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
|
||||
filter));
|
||||
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
|
@ -39,13 +41,14 @@ private:
|
|||
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
|
||||
if (isa<AddFOp, SelectOp, CmpFOp>(op))
|
||||
return SmallVector<int64_t, 4>(2, 2);
|
||||
if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
|
||||
return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
|
||||
}
|
||||
if (isa<vector::ContractionOp>(op))
|
||||
return SmallVector<int64_t, 4>(3, 2);
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Filter constraint handed to UnrollVectorOptions::setFilterConstraint in
/// this test pass. Succeeds only when `op` is an AddFOp, SelectOp, CmpFOp or
/// ContractionOp; every other operation yields failure.
static LogicalResult filter(Operation *op) {
|
||||
return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestVectorSlicesConversion
|
||||
|
|
Loading…
Reference in New Issue