[mlir][vector] Add unrolling patterns for Transfer read/write

Add unroll support for the transfer read and transfer write operations. This
makes it possible to pick the ideal memory access size for a given target.

Differential Revision: https://reviews.llvm.org/D89289
Author: Thomas Raoux
Date:   2020-10-15 09:47:58 -07:00
Commit: edbdea7466 (parent 4c1c88bbc1)
4 changed files with 216 additions and 60 deletions


@@ -69,6 +69,22 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
Operation *op,
ArrayRef<int64_t> targetShape);
/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
/// transfer_write.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
/// ->
/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
/// vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
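//
// A minimal caller sketch (hypothetical, not part of this patch; the helper
// name and the 2x4 tile size are made up), mirroring what UnrollVectorPattern
// below does for transfer_write: the op has no results, so on success the
// original op is simply erased.
//
//   static void unrollWriteTo2x4(Operation *op) {
//     OpBuilder builder(op); // Insert the new ops right before `op`.
//     if (succeeded(vector::unrollTransferWriteOp(builder, op, {2, 4})))
//       op->erase(); // No results to replace; drop the coarse write.
//   }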
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
template <typename OpTy>
@@ -95,6 +111,12 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
if (!maybeShapeRatio ||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
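    // transfer_write has no results, so it cannot go through the
    // single-result unrolling below; unroll it in place instead and erase
    // the original op on success.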
if (std::is_same<OpTy, TransferWriteOp>::value) {
if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
return failure();
rewriter.eraseOp(op);
return success();
}
if (op.getOperation()->getNumResults() != 1)
return failure();
auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);


@@ -511,35 +511,6 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
resultIndex = numVectors - 1;
}
// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
SmallVector<int64_t, 6> iterationBounds;
auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
std::vector<VectorState> vectors;
unsigned resultIndex;
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
// Populate state for vector ContractionOp.
getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
resultIndex);
} else {
// Populate state for vector elementwise op.
getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}
/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
/// calls 'fn' with linear index and indices for each slice.
static void generateTransferOpSlices(
@@ -615,6 +586,114 @@ static bool isIdentitySuffix(AffineMap map) {
return true;
}
/// Unroll transfer_read ops to the given shape and create an aggregate with all
/// the chunks.
static Value unrollTransferReadOp(vector::TransferReadOp readOp,
ArrayRef<int64_t> targetShape,
OpBuilder &builder) {
if (!isIdentitySuffix(readOp.permutation_map()))
return nullptr;
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
Location loc = readOp.getLoc();
auto memrefElementType =
readOp.memref().getType().cast<MemRefType>().getElementType();
auto tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
int64_t numSlices = tupleType.size();
SmallVector<Value, 4> vectorTupleValues(numSlices);
SmallVector<Value, 4> indices(readOp.indices().begin(),
readOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Get the VectorType for slice 'index'.
auto sliceVectorType = tupleType.getType(index);
// Create a split TransferReadOp for this slice.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
loc, sliceVectorType, readOp.memref(), sliceIndices,
readOp.permutation_map(), readOp.padding(),
readOp.masked() ? *readOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
// Create a tuple of the sliced transfer read operations.
Value tupleOp =
builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
// Reassemble the slices into a vector of the original shape.
Value newVec = builder.create<vector::InsertSlicesOp>(
loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
builder.getI64ArrayAttr(strides));
return newVec;
}
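// Illustrative only, not part of this patch: the constant slice offsets used
// above come from generateTransferOpSlices, whose body is outside this diff.
// Assuming an evenly divisible shape, the index arithmetic it performs
// amounts to delinearizing a row-major slice number into per-dimension
// element offsets (the real helper also materializes these as index Values
// added to the op's base indices).
static SmallVector<int64_t, 4>
computeSliceOffsets(ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> targetShape,
                    int64_t linearIndex) {
  // Number of slices per dimension, e.g. 4x4 unrolled to 2x2 gives 2x2.
  SmallVector<int64_t, 4> sliceCounts;
  for (unsigned d = 0, e = sourceShape.size(); d < e; ++d)
    sliceCounts.push_back(sourceShape[d] / targetShape[d]);
  // Peel dimensions innermost-first; scale each slice coordinate by the
  // target shape to obtain an element offset.
  SmallVector<int64_t, 4> offsets(sourceShape.size());
  for (int64_t d = static_cast<int64_t>(sourceShape.size()) - 1; d >= 0; --d) {
    offsets[d] = (linearIndex % sliceCounts[d]) * targetShape[d];
    linearIndex /= sliceCounts[d];
  }
  // For the 4x4 -> 2x2 case in the tests, slice 2 maps to offsets (2, 0),
  // matching the transfer_read at [%c2, %c0].
  return offsets;
}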
// Entry point for unrolling declarative pattern rewrites of transfer_write ops.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape) {
auto writeOp = cast<vector::TransferWriteOp>(op);
if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
VectorType sourceVectorType = writeOp.getVectorType();
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
TupleType tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
Location loc = writeOp.getLoc();
Value tuple = builder.create<vector::ExtractSlicesOp>(
loc, tupleType, writeOp.vector(), targetShape, strides);
auto memrefElementType =
writeOp.memref().getType().cast<MemRefType>().getElementType();
SmallVector<Value, 4> indices(writeOp.indices().begin(),
writeOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
auto element = builder.create<vector::TupleGetOp>(
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
builder.create<vector::TransferWriteOp>(
loc, element.getResult(), writeOp.memref(), sliceIndices,
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
return success();
}
// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
SmallVector<int64_t, 6> iterationBounds;
auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
std::vector<VectorState> vectors;
unsigned resultIndex;
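  // transfer_read gets special handling: its slices are reassembled with
  // vector.tuple + insert_slices rather than taking the generic
  // structured-op path below.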
if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
return SmallVector<Value, 1>{
unrollTransferReadOp(readOp, targetShape, builder)};
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
// Populate state for vector ContractionOp.
getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
resultIndex);
} else {
// Populate state for vector elementwise op.
getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}
namespace {
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
@@ -636,43 +715,16 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
return failure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
auto sourceVectorType = extractSlicesOp.getSourceVectorType();
auto resultTupleType = extractSlicesOp.getResultTupleType();
SmallVector<int64_t, 4> sizes;
extractSlicesOp.getSizes(sizes);
SmallVector<int64_t, 4> strides;
extractSlicesOp.getStrides(strides);
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
Location loc = xferReadOp.getLoc();
auto memrefElementType =
xferReadOp.memref().getType().cast<MemRefType>().getElementType();
int64_t numSlices = resultTupleType.size();
SmallVector<Value, 4> vectorTupleValues(numSlices);
SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
xferReadOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
xferReadOp.permutation_map(), xferReadOp.padding(),
xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
};
generateTransferOpSlices(memrefElementType, sourceVectorType,
resultTupleType, sizes, strides, indices, rewriter,
createSlice);
// Create tuple of splice xfer read operations.
Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
vectorTupleValues);
// Replace 'xferReadOp' with result 'insertSlicesResult'.
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
extractSlicesOp.strides());
Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
if (!newVec)
return failure();
rewriter.replaceOp(xferReadOp, newVec);
return success();
}
};


@@ -0,0 +1,60 @@
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
// CHECK-LABEL: func @transfer_read_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>
func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
return %0 : vector<4x4xf32>
}
// CHECK-LABEL: func @transfer_write_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return
func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
%c0 = constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
// CHECK-LABEL: func @transfer_readwrite_unroll
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return
func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}


@@ -156,6 +156,24 @@ struct TestVectorDistributePatterns
}
};
struct TestVectorTransferUnrollingPatterns
: public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
ArrayRef<int64_t>{2, 2}, ctx);
patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
ArrayRef<int64_t>{2, 2}, ctx);
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
@@ -205,6 +223,10 @@ void registerTestVectorConversions() {
"test-vector-unrolling-patterns",
"Test conversion patterns to unroll contract ops in the vector dialect");
PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
"test-vector-transfer-unrolling-patterns",
"Test conversion patterns to unroll transfer ops in the vector dialect");
PassRegistration<TestVectorTransferFullPartialSplitPatterns>
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
"Test conversion patterns to split "