forked from OSchip/llvm-project
[mlir][vector] Extend vector transfer unrolling to support permutations and broadcast
Differential Revision: https://reviews.llvm.org/D101637
This commit is contained in:
parent
7417541fd8
commit
f44c76d6e9
|
@ -516,10 +516,12 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
|
|||
|
||||
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
|
||||
/// calls 'fn' with linear index and indices for each slice.
|
||||
static void generateTransferOpSlices(
|
||||
Type shapedElementType, VectorType vectorType, TupleType tupleType,
|
||||
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
|
||||
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
|
||||
static void
|
||||
generateTransferOpSlices(Type shapedElementType, VectorType vectorType,
|
||||
TupleType tupleType, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides, ArrayRef<Value> indices,
|
||||
AffineMap permutationMap, OpBuilder &builder,
|
||||
function_ref<void(unsigned, ArrayRef<Value>)> fn) {
|
||||
// Compute strides w.r.t. to slice counts in each dimension.
|
||||
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
|
||||
assert(maybeDimSliceCounts.hasValue());
|
||||
|
@ -527,7 +529,6 @@ static void generateTransferOpSlices(
|
|||
auto sliceStrides = computeStrides(sliceDimCounts);
|
||||
|
||||
int64_t numSlices = tupleType.size();
|
||||
unsigned numSliceIndices = indices.size();
|
||||
// Compute 'indexOffset' at which to update 'indices', which is equal
|
||||
// to the memref rank (indices.size) minus the effective 'vectorRank'.
|
||||
// The effective 'vectorRank', is equal to the rank of the vector type
|
||||
|
@ -545,57 +546,38 @@ static void generateTransferOpSlices(
|
|||
assert(vectorRank >= sourceVectorElementType.getRank());
|
||||
vectorRank -= sourceVectorElementType.getRank();
|
||||
}
|
||||
unsigned indexOffset = numSliceIndices - vectorRank;
|
||||
|
||||
auto isBroadcast = [](AffineExpr expr) {
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
|
||||
return constExpr.getValue() == 0;
|
||||
return false;
|
||||
};
|
||||
auto *ctx = builder.getContext();
|
||||
for (unsigned i = 0; i < numSlices; ++i) {
|
||||
auto vectorOffsets = delinearize(sliceStrides, i);
|
||||
auto elementOffsets =
|
||||
computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
|
||||
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
|
||||
SmallVector<Value, 4> sliceIndices(numSliceIndices);
|
||||
for (unsigned j = 0; j < numSliceIndices; ++j) {
|
||||
if (j < indexOffset) {
|
||||
sliceIndices[j] = indices[j];
|
||||
} else {
|
||||
SmallVector<Value, 4> sliceIndices(indices.begin(), indices.end());
|
||||
for (auto dim : llvm::enumerate(permutationMap.getResults())) {
|
||||
if (isBroadcast(dim.value()))
|
||||
continue;
|
||||
unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
|
||||
auto expr = getAffineDimExpr(0, ctx) +
|
||||
getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
|
||||
getAffineConstantExpr(elementOffsets[dim.index()], ctx);
|
||||
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||
sliceIndices[j] = builder.create<AffineApplyOp>(
|
||||
indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
|
||||
}
|
||||
sliceIndices[pos] = builder.create<AffineApplyOp>(
|
||||
indices[pos].getLoc(), map, ArrayRef<Value>(indices[pos]));
|
||||
}
|
||||
// Call 'fn' to generate slice 'i' at 'sliceIndices'.
|
||||
fn(i, sliceIndices);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if 'map' is a suffix of an identity affine map, false
|
||||
/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
|
||||
static bool isIdentitySuffix(AffineMap map) {
|
||||
if (map.getNumDims() < map.getNumResults())
|
||||
return false;
|
||||
ArrayRef<AffineExpr> results = map.getResults();
|
||||
Optional<int> lastPos;
|
||||
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
auto expr = results[i].dyn_cast<AffineDimExpr>();
|
||||
if (!expr)
|
||||
return false;
|
||||
int currPos = static_cast<int>(expr.getPosition());
|
||||
if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
|
||||
return false;
|
||||
lastPos = currPos;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Unroll transfer_read ops to the given shape and create an aggregate with all
|
||||
/// the chunks.
|
||||
static Value unrollTransferReadOp(vector::TransferReadOp readOp,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
OpBuilder &builder) {
|
||||
if (!isIdentitySuffix(readOp.permutation_map()))
|
||||
return nullptr;
|
||||
if (readOp.mask())
|
||||
return nullptr;
|
||||
auto sourceVectorType = readOp.getVectorType();
|
||||
|
@ -623,7 +605,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
|
|||
readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
|
||||
};
|
||||
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
|
||||
targetShape, strides, indices, builder, createSlice);
|
||||
targetShape, strides, indices,
|
||||
readOp.permutation_map(), builder, createSlice);
|
||||
|
||||
// Create tuple of splice transfer read operations.
|
||||
Value tupleOp =
|
||||
|
@ -641,8 +624,6 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
|
|||
ArrayRef<int64_t> targetShape,
|
||||
SmallVector<Value, 1> &result) {
|
||||
auto writeOp = cast<vector::TransferWriteOp>(op);
|
||||
if (!isIdentitySuffix(writeOp.permutation_map()))
|
||||
return failure();
|
||||
if (writeOp.mask())
|
||||
return failure();
|
||||
VectorType sourceVectorType = writeOp.getVectorType();
|
||||
|
@ -671,7 +652,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
|
|||
resultTensor = write->getResult(0);
|
||||
};
|
||||
generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
|
||||
targetShape, strides, indices, builder, createSlice);
|
||||
targetShape, strides, indices,
|
||||
writeOp.permutation_map(), builder, createSlice);
|
||||
if (resultTensor)
|
||||
result.push_back(resultTensor);
|
||||
return success();
|
||||
|
@ -729,11 +711,6 @@ public:
|
|||
if (readOp.mask())
|
||||
return failure();
|
||||
|
||||
// TODO: Support splitting TransferReadOp with non-identity permutation
|
||||
// maps. Repurpose code from MaterializeVectors transformation.
|
||||
if (!isIdentitySuffix(readOp.permutation_map()))
|
||||
return failure();
|
||||
|
||||
// Return unless there is only one user, and it is an ExtractSlicesOp.
|
||||
Value readResult = readOp.getResult();
|
||||
if (!readResult.hasOneUse())
|
||||
|
@ -778,11 +755,6 @@ public:
|
|||
if (writeOp.mask())
|
||||
return failure();
|
||||
|
||||
// TODO: Support splitting TransferWriteOp with non-identity permutation
|
||||
// maps. Repurpose code from MaterializeVectors transformation.
|
||||
if (!isIdentitySuffix(writeOp.permutation_map()))
|
||||
return failure();
|
||||
|
||||
// Fail to match unless this is writing a vector resulting from an
|
||||
// InsertSlicesOp.
|
||||
auto insertSlicesOp =
|
||||
|
@ -821,8 +793,8 @@ public:
|
|||
resultTensor = write->getResult(0);
|
||||
};
|
||||
generateTransferOpSlices(shapedElementType, resultVectorType,
|
||||
sourceTupleType, sizes, strides, indices, rewriter,
|
||||
createSlice);
|
||||
sourceTupleType, sizes, strides, indices,
|
||||
writeOp.permutation_map(), rewriter, createSlice);
|
||||
|
||||
if (resultTensor)
|
||||
rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
|
@ -120,3 +120,94 @@ func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4
|
|||
%r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
|
||||
return %r: tensor<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll_permutation
|
||||
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
|
||||
// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32>
|
||||
#map0 = affine_map<(d0, d1) -> (d1, d0)>
|
||||
func @transfer_read_unroll_permutation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
|
||||
return %0 : vector<4x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll_broadcast
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
|
||||
// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32>
|
||||
#map0 = affine_map<(d0, d1) -> (0, d1)>
|
||||
func @transfer_read_unroll_broadcast(%arg0 : memref<6x4xf32>) -> vector<6x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
|
||||
return %0 : vector<6x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll_broadcast_permuation
|
||||
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
|
||||
// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32>
|
||||
#map0 = affine_map<(d0, d1) -> (0, d0)>
|
||||
func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
|
||||
return %0 : vector<4x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_unroll_different_rank
|
||||
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
|
||||
// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32>
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
|
||||
func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%cf0 = constant 0.0 : f32
|
||||
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?xf32>, vector<6x4xf32>
|
||||
return %0 : vector<6x4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue