[mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.

Summary:
Adds two rewrite patterns for the vector ShapeCastOp:

*) ShapeCastOp decomposer: decomposes a ShapeCastOp on tuple-of-vectors into
   multiple ShapeCastOps, each on vector types.
*) ShapeCastOp folder: folds canceling shape cast ops (e.g. shape_cast A -> B
   followed by shape_cast B -> A) away.

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74327
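Illustration (a sketch only; the shapes are taken from the shape_cast_decomposition test added in this change, and the SSA value names are illustrative): a shape cast on a tuple of vectors such as

  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                              tuple<vector<20x2xf32>, vector<12x2xf32>>

is decomposed into one vector.shape_cast per tuple element; once the existing TupleGetFolderOp pattern forwards the tuple elements to their users, the canonicalized form is just

  %0 = vector.shape_cast %arg0 : vector<5x4x2xf32> to vector<20x2xf32>
  %1 = vector.shape_cast %arg1 : vector<3x4x2xf32> to vector<12x2xf32>

as checked by the new test.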
@@ -646,6 +646,90 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
  }
};

/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
/// on vector types.
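//
// Example (an illustrative sketch of this pattern's rewrite in isolation,
// before tuple_get forwarding; SSA names are arbitrary):
//
//  A shape cast on a tuple of vectors:
//
//   %1 = shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
//                        tuple<vector<20x2xf32>, vector<12x2xf32>>
//
//  becomes one shape cast per tuple element, re-packed into a tuple:
//
//   %1 = tuple_get %0, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
//   %2 = tuple_get %0, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
//   %3 = shape_cast %1 : vector<5x4x2xf32> to vector<20x2xf32>
//   %4 = shape_cast %2 : vector<3x4x2xf32> to vector<12x2xf32>
//   %5 = tuple %3, %4 : vector<20x2xf32>, vector<12x2xf32>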
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                     PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has tuple source/result type.
    auto sourceTupleType =
        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
    auto resultTupleType =
        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
    if (!sourceTupleType || !resultTupleType)
      return matchFailure();
    assert(sourceTupleType.size() == resultTupleType.size());

    // Create single-vector ShapeCastOp for each source tuple element.
    Location loc = shapeCastOp.getLoc();
    SmallVector<Value, 8> resultElements;
    resultElements.reserve(resultTupleType.size());
    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
      auto sourceElement = rewriter.create<vector::TupleGetOp>(
          loc, sourceTupleType.getType(i), shapeCastOp.source(),
          rewriter.getI64IntegerAttr(i));
      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
          loc, resultTupleType.getType(i), sourceElement));
    }

    // Replace 'shapeCastOp' with tuple of 'resultElements'.
    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
                                                 resultElements);
    return matchSuccess();
  }
};

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                     PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return matchFailure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return matchFailure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType =
        sourceShapeCastOp.result().getType().cast<VectorType>();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return matchFailure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return matchSuccess();
  }
};

// Pattern rewrite which forwards tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
// -> User(Producer)
@@ -784,8 +868,8 @@ public:
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
-      context);
+  patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
+                  SplitTransferWriteOp, TupleGetFolderOp>(context);
}

void mlir::vector::populateVectorSlicesLoweringPatterns(
@@ -346,3 +346,62 @@ func @vector_transfers_vector_element_type() {

  return
}

// Test that ShapeCastOp on a tuple of vectors decomposes to multiple
// ShapeCastOps on vectors.
// CHECK-LABEL: func @shape_cast_decomposition
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>

func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
                               %arg1 : vector<3x4x2xf32>)
  -> (vector<20x2xf32>, vector<12x2xf32>) {
  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                              tuple<vector<20x2xf32>, vector<12x2xf32>>
  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
}

// Test that cancelling ShapeCastOps are canonicalized away.
// EX:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//

// CHECK-LABEL: func @shape_cast_fold
// CHECK: return %{{.*}}, %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>

func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
  -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>

  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                              tuple<vector<20x2xf32>, vector<12x2xf32>>

  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>

  %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
  %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
                              tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
  %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
}