[mlir][vector] Add unrolling patterns for Transfer read/write
Add unroll support for the transfer read and transfer write operations. This makes it possible to pick the ideal memory-access size for a given target.
Differential Revision: https://reviews.llvm.org/D89289
commit edbdea7466 (parent 4c1c88bbc1)
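The target shape is a parameter of the new patterns, which is what lets a backend pick the memory-access granularity it prefers. As a rough illustration only (the helper name, the `nativeShape` parameter, and the header set are hypothetical; the pattern population itself mirrors the TestVectorTransferUnrollingPatterns test pass added at the end of this diff), a target-specific pipeline could populate the patterns like this:

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"

using namespace mlir;
using namespace mlir::vector;

// Sketch only (not part of this commit): populate the new transfer unrolling
// patterns with a shape chosen by the target; `nativeShape` would typically be
// derived from the target's preferred vector/memory width.
static void populateTargetTransferUnrollPatterns(OwningRewritePatternList &patterns,
                                                 MLIRContext *ctx,
                                                 ArrayRef<int64_t> nativeShape) {
  // Unroll both transfer directions to the target-chosen shape.
  patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(nativeShape, ctx);
  patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(nativeShape, ctx);
  // Fold away the tuple / insert_slices scaffolding produced by unrolling.
  populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
  populateVectorToVectorTransformationPatterns(patterns, ctx);
}

The test pass later in this diff does the same thing with a hard-coded 2x2 shape and then runs applyPatternsAndFoldGreedily over the function.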
@@ -69,6 +69,22 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
                                                 Operation *op,
                                                 ArrayRef<int64_t> targetShape);

/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
/// transfer_write.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
/// ->
/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
///        vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape);

/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
template <typename OpTy>
@@ -95,6 +111,12 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
    if (!maybeShapeRatio ||
        llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
      return failure();
    if (std::is_same<OpTy, TransferWriteOp>::value) {
      if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
        return failure();
      rewriter.eraseOp(op);
      return success();
    }
    if (op.getOperation()->getNumResults() != 1)
      return failure();
    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
@@ -511,35 +511,6 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
  resultIndex = numVectors - 1;
}

// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
                                         ArrayRef<int64_t> targetShape) {
  assert(op->getNumResults() == 1 && "Expected single result operation");

  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
  SmallVector<int64_t, 6> iterationBounds;
  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");

  std::vector<VectorState> vectors;
  unsigned resultIndex;

  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
    // Populate state for vector ContractionOp.
    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
                                      resultIndex);
  } else {
    // Populate state for vector elementwise op.
    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
  }

  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}

/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
static void generateTransferOpSlices(
@@ -615,6 +586,114 @@ static bool isIdentitySuffix(AffineMap map) {
  return true;
}

/// Unroll transfer_read ops to the given shape and create an aggregate with all
/// the chunks.
static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                  ArrayRef<int64_t> targetShape,
                                  OpBuilder &builder) {
  if (!isIdentitySuffix(readOp.permutation_map()))
    return nullptr;
  auto sourceVectorType = readOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);

  Location loc = readOp.getLoc();
  auto memrefElementType =
      readOp.memref().getType().cast<MemRefType>().getElementType();
  auto tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  int64_t numSlices = tupleType.size();

  SmallVector<Value, 4> vectorTupleValues(numSlices);
  SmallVector<Value, 4> indices(readOp.indices().begin(),
                                readOp.indices().end());
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    // Get VectorType for slice 'i'.
    auto sliceVectorType = tupleType.getType(index);
    // Create split TransferReadOp for 'sliceUser'.
    // `masked` attribute propagates conservatively: if the coarse op didn't
    // need masking, the fine op doesn't either.
    vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
        loc, sliceVectorType, readOp.memref(), sliceIndices,
        readOp.permutation_map(), readOp.padding(),
        readOp.masked() ? *readOp.masked() : ArrayAttr());
  };
  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);

  // Create tuple of splice transfer read operations.
  Value tupleOp =
      builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
  // Replace 'readOp' with result 'insertSlicesResult'.
  Value newVec = builder.create<vector::InsertSlicesOp>(
      loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
      builder.getI64ArrayAttr(strides));
  return newVec;
}

// Entry point for unrolling declarative pattern rewrite for transfer_write op.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape) {
  auto writeOp = cast<vector::TransferWriteOp>(op);
  if (!isIdentitySuffix(writeOp.permutation_map()))
    return failure();
  VectorType sourceVectorType = writeOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
  TupleType tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  Location loc = writeOp.getLoc();
  Value tuple = builder.create<vector::ExtractSlicesOp>(
      loc, tupleType, writeOp.vector(), targetShape, strides);
  auto memrefElementType =
      writeOp.memref().getType().cast<MemRefType>().getElementType();
  SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                writeOp.indices().end());
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    auto element = builder.create<vector::TupleGetOp>(
        loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
    builder.create<vector::TransferWriteOp>(
        loc, element.getResult(), writeOp.memref(), sliceIndices,
        writeOp.permutation_map(),
        writeOp.masked() ? *writeOp.masked() : ArrayAttr());
  };
  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);
  return success();
}

// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
                                         ArrayRef<int64_t> targetShape) {
  assert(op->getNumResults() == 1 && "Expected single result operation");

  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
  SmallVector<int64_t, 6> iterationBounds;
  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");

  std::vector<VectorState> vectors;
  unsigned resultIndex;

  if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
    return SmallVector<Value, 1>{
        unrollTransferReadOp(readOp, targetShape, builder)};

  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
    // Populate state for vector ContractionOp.
    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
                                      resultIndex);
  } else {
    // Populate state for vector elementwise op.
    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
  }

  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}

namespace {

// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
@@ -636,43 +715,16 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
      return failure();

    // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
    auto sourceVectorType = extractSlicesOp.getSourceVectorType();
    auto resultTupleType = extractSlicesOp.getResultTupleType();
    SmallVector<int64_t, 4> sizes;
    extractSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    extractSlicesOp.getStrides(strides);
    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));

    Location loc = xferReadOp.getLoc();
    auto memrefElementType =
        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
    int64_t numSlices = resultTupleType.size();
    SmallVector<Value, 4> vectorTupleValues(numSlices);
    SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
                                  xferReadOp.indices().end());
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
      // Get VectorType for slice 'i'.
      auto sliceVectorType = resultTupleType.getType(index);
      // Create split TransferReadOp for 'sliceUser'.
      // `masked` attribute propagates conservatively: if the coarse op didn't
      // need masking, the fine op doesn't either.
      vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
          loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
          xferReadOp.permutation_map(), xferReadOp.padding(),
          xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
    };
    generateTransferOpSlices(memrefElementType, sourceVectorType,
                             resultTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Create tuple of splice xfer read operations.
    Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
                                                     vectorTupleValues);
    // Replace 'xferReadOp' with result 'insertSlicesResult'.
    rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
        xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
        extractSlicesOp.strides());
    Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
    if (!newVec)
      return failure();
    rewriter.replaceOp(xferReadOp, newVec);
    return success();
  }
};
@@ -0,0 +1,60 @@
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s

// CHECK-LABEL: func @transfer_read_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>

func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  return %0 : vector<4x4xf32>
}

// CHECK-LABEL: func @transfer_write_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return

func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
  %c0 = constant 0 : index
  vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
  return
}

// CHECK-LABEL: func @transfer_readwrite_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return

func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
  return
}
@@ -156,6 +156,24 @@ struct TestVectorDistributePatterns
  }
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
        ArrayRef<int64_t>{2, 2}, ctx);
    patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
        ArrayRef<int64_t>{2, 2}, ctx);
    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
    populateVectorToVectorTransformationPatterns(patterns, ctx);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         FunctionPass> {
@@ -205,6 +223,10 @@ void registerTestVectorConversions() {
      "test-vector-unrolling-patterns",
      "Test conversion patterns to unroll contract ops in the vector dialect");

  PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
      "test-vector-transfer-unrolling-patterns",
      "Test conversion patterns to unroll transfer ops in the vector dialect");

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>
      vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
                                     "Test conversion patterns to split "