[mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.

Summary:
Adds two rewrite patterns for the vector ShapeCastOp.
*) ShapeCastOp decomposer: decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps each on vector types.
*) ShapeCastOp folder: folds canceling shape cast ops (e.g. shape_cast A -> B followed by shape_cast B -> A) away.

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74327
This commit is contained in:
Andy Davis 2020-02-11 12:57:57 -08:00
parent 5e37fb1776
commit 813bfffec3
2 changed files with 145 additions and 2 deletions

View File

@ -646,6 +646,90 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
}
};
/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
/// on vector types.
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has tuple source/result type.
auto sourceTupleType =
shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
auto resultTupleType =
shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
if (!sourceTupleType || !resultTupleType)
return matchFailure();
assert(sourceTupleType.size() == resultTupleType.size());
// Create single-vector ShapeCastOp for each source tuple element.
Location loc = shapeCastOp.getLoc();
SmallVector<Value, 8> resultElements;
resultElements.reserve(resultTupleType.size());
for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
auto sourceElement = rewriter.create<vector::TupleGetOp>(
loc, sourceTupleType.getType(i), shapeCastOp.source(),
rewriter.getI64IntegerAttr(i));
resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
loc, resultTupleType.getType(i), sourceElement));
}
// Replace 'shapeCastOp' with tuple of 'resultElements'.
rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
resultElements);
return matchSuccess();
}
};
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
// The following MLIR with cancelling ShapeCastOps:
//
// %0 = source : vector<5x4x2xf32>
// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
// %3 = user %2 : vector<5x4x2xf32>
//
// Should canonicalize to the following:
//
// %0 = source : vector<5x4x2xf32>
// %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
auto resultVectorType =
shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
if (!sourceVectorType || !resultVectorType)
return matchFailure();
// Check if shape cast op source operand is also a shape cast op.
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
shapeCastOp.source().getDefiningOp());
if (!sourceShapeCastOp)
return matchFailure();
auto operandSourceVectorType =
sourceShapeCastOp.source().getType().cast<VectorType>();
auto operandResultVectorType =
sourceShapeCastOp.result().getType().cast<VectorType>();
// Check if shape cast operations invert each other.
if (operandSourceVectorType != resultVectorType ||
operandResultVectorType != sourceVectorType)
return matchFailure();
rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
return matchSuccess();
}
};
// Patter rewrite which forward tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
// -> User(Producer)
@ -784,8 +868,8 @@ public:
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
context);
patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
SplitTransferWriteOp, TupleGetFolderOp>(context);
}
void mlir::vector::populateVectorSlicesLoweringPatterns(

View File

@ -346,3 +346,62 @@ func @vector_transfers_vector_element_type() {
return
}
// Test that ShapeCastOp on tuple of vectors, decomposes to multiple
// ShapeCastOps on vectors.
// CHECK-LABEL: func @shape_cast_decomposition
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
%arg1 : vector<3x4x2xf32>)
-> (vector<20x2xf32>, vector<12x2xf32>) {
%0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
%1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
tuple<vector<20x2xf32>, vector<12x2xf32>>
%2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
%3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
}
// Test that cancelling ShapeCastOps are canonicalized away.
// EX:
//
// The following MLIR with cancelling ShapeCastOps:
//
// %0 = source : vector<5x4x2xf32>
// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
// %3 = user %2 : vector<5x4x2xf32>
//
// Should canonicalize to the following:
//
//
// %0 = source : vector<5x4x2xf32>
// %1 = user %0 : vector<5x4x2xf32>
//
// ShapeCastOps on vectors.
// CHECK-LABEL: func @shape_cast_fold
// CHECK: return %{{.*}}, %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>
func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
-> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
%0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
%1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
tuple<vector<20x2xf32>, vector<12x2xf32>>
%2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
%3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
%4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
%5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
%6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
%7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
}