[mlir][tensor] Insert explicit tensor.cast ops for insert_slice src

If additional static type information can be deduced from a insert_slice's size operands, insert an explicit cast of the op's source operand.

This enables other canonicalization patterns that are matching for tensor_cast ops such as `ForOpTensorCastFolder` in SCF.

Differential Revision: https://reviews.llvm.org/D108617
This commit is contained in:
Matthias Springer 2021-08-24 19:41:16 +09:00
parent 0c36082963
commit ebf35370ff
3 changed files with 96 additions and 7 deletions

View File

@ -1085,7 +1085,24 @@ public:
}
};
/// Fold tensor_casts with insert_slice operations.
/// Fold tensor_casts with insert_slice operations. If the source or destination
/// tensor is a tensor_cast that removes static type information, the cast is
/// folded into the insert_slice operation. E.g.:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
return success();
}
};
/// If additional static type information can be deduced from a insert_slice's
/// size operands, insert an explicit cast of the op's source operand. This
/// enables other canonicalization patterns that are matching for tensor_cast
/// ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
/// : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
/// : tensor<64x64xf32> into ...
/// ```
struct InsertSliceOpSourceCastInserter final
: public OpRewritePattern<InsertSliceOp> {
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType srcType = insertSliceOp.getSourceType();
if (srcType.getRank() != insertSliceOp.getType().getRank())
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
srcType.getShape().end());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (Optional<int64_t> constInt =
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
newSrcShape[i] = *constInt;
}
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
if (srcType == newSrcType)
return failure();
// srcType and newSrcType are different. Insert a cast.
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
rewriter.replaceOpWithNewOp<InsertSliceOp>(
insertSliceOp, cast, insertSliceOp.dest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
return success();
}
};
} // namespace
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
context);
results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
InsertSliceOpSourceCastInserter>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -666,7 +666,7 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
return %res : tensor<1024x1024xf32>
}
// -----
// CHECK-LABEL: @cond_prop
func @cond_prop(%arg0 : i1) -> index {
@ -707,6 +707,8 @@ func @cond_prop(%arg0 : i1) -> index {
// CHECK-NEXT: return %[[if]] : index
// CHECK-NEXT:}
// -----
// CHECK-LABEL: @replace_if_with_cond1
func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = constant true
@ -729,6 +731,8 @@ func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
// -----
// CHECK-LABEL: @replace_if_with_cond2
func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
%true = constant true
@ -753,6 +757,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
// -----
// CHECK-LABEL: @replace_if_with_cond3
func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
@ -774,6 +779,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
// -----
// CHECK-LABEL: @while_cond_true
func @while_cond_true() {

View File

@ -366,10 +366,11 @@ func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
}
// CHECK-LABEL: func @insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
// CHEKC: return %[[RESULT]]
// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]
// -----
@ -517,3 +518,17 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%2 = tensor.dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}
// -----
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
// CHECK: return %[[r]]
func @insert_tensor_cast_on_insert_slice_src(
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
return %r : tensor<?x?x?xf32>
}