[mlir][tensor] Insert explicit tensor.cast ops for insert_slice src
If additional static type information can be deduced from an insert_slice's size operands, insert an explicit cast of the op's source operand. This enables other canonicalization patterns that match tensor_cast ops, such as `ForOpTensorCastFolder` in SCF.

Differential Revision: https://reviews.llvm.org/D108617
parent 0c36082963
commit ebf35370ff
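As a rough illustration of the rewrite this patch adds (the value names and shapes below are invented for illustration, not taken from the patch): when an insert_slice has constant size operands but a dynamically shaped source, the deduced static type is materialized as an explicit tensor.cast on the source, which gives tensor_cast-matching patterns such as `ForOpTensorCastFolder` something to latch onto.

```mlir
// Before canonicalization: the sizes [4, 4] are static, but the source type is not.
%r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
    : tensor<?x?xf32> into tensor<?x?xf32>

// After the new pattern: the deduced source type is made explicit via a cast.
%src_cast = tensor.cast %src : tensor<?x?xf32> to tensor<4x4xf32>
%r2 = tensor.insert_slice %src_cast into %dest[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf32> into tensor<?x?xf32>
```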
@@ -1085,7 +1085,24 @@ public:
   }
 };
 
-/// Fold tensor_casts with insert_slice operations.
+/// Fold tensor_casts with insert_slice operations. If the source or destination
+/// tensor is a tensor_cast that removes static type information, the cast is
+/// folded into the insert_slice operation. E.g.:
+///
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
+/// ```
+///
+/// Note: When folding a cast on the destination tensor, the result of the
+/// insert_slice operation is casted to ensure that the type of the result did
+/// not change.
 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
 
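The destination-cast case mentioned in the note above is not shown in the doc comment's examples; a minimal sketch of what it means (all value names here are hypothetical):

```mlir
// A cast that erases the destination's static shape ...
%d = tensor.cast %dest : tensor<8x16xf32> to tensor<?x?xf32>
%r = tensor.insert_slice %src into %d[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf32> into tensor<?x?xf32>

// ... is folded away; the result is casted back so users of %r still see the
// original tensor<?x?xf32> type.
%i = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf32> into tensor<8x16xf32>
%r2 = tensor.cast %i : tensor<8x16xf32> to tensor<?x?xf32>
```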

@@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
     return success();
   }
 };
+
+/// If additional static type information can be deduced from an insert_slice's
+/// size operands, insert an explicit cast of the op's source operand. This
+/// enables other canonicalization patterns that are matching for tensor_cast
+/// ops such as `ForOpTensorCastFolder` in SCF.
+///
+/// Example:
+///
+/// ```mlir
+///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
+///       : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
+///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
+///       : tensor<64x64xf32> into ...
+/// ```
+struct InsertSliceOpSourceCastInserter final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType srcType = insertSliceOp.getSourceType();
+    if (srcType.getRank() != insertSliceOp.getType().getRank())
+      return failure();
+    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
+                                     srcType.getShape().end());
+    for (int64_t i = 0; i < srcType.getRank(); ++i) {
+      if (Optional<int64_t> constInt =
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+        newSrcShape[i] = *constInt;
+    }
+    RankedTensorType newSrcType =
+        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    if (srcType == newSrcType)
+      return failure();
+
+    // srcType and newSrcType are different. Insert a cast.
+    Value cast = rewriter.create<tensor::CastOp>(
+        insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, cast, insertSliceOp.dest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return success();
+  }
+};
 } // namespace
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
-      context);
+  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
+              InsertSliceOpSourceCastInserter>(context);
 }
 
 //===----------------------------------------------------------------------===//
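The per-dimension loop in `matchAndRewrite` only tightens dimensions whose size operand folds to a constant via `getConstantIntValue`; dimensions with runtime sizes keep the source's existing extent. A small sketch of the resulting IR for a mixed case (the values `%off` and `%sz` are hypothetical), which matches the updated `insert_slice_canonicalize` test further down:

```mlir
// Sizes are [4, 1, %sz]: only the first two are constant, so the deduced
// source type is tensor<4x1x?xf32> and a cast to exactly that type is inserted.
%cast = tensor.cast %arg0 : tensor<?x?x?xf32> to tensor<4x1x?xf32>
%r = tensor.insert_slice %cast into %arg1[0, %off, 1] [4, 1, %sz] [1, 1, 1]
    : tensor<4x1x?xf32> into tensor<?x?x?xf32>
```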

@@ -666,7 +666,7 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
   return %res : tensor<1024x1024xf32>
 }
 
-
+// -----
 
 // CHECK-LABEL: @cond_prop
 func @cond_prop(%arg0 : i1) -> index {

@@ -707,6 +707,8 @@ func @cond_prop(%arg0 : i1) -> index {
 // CHECK-NEXT: return %[[if]] : index
 // CHECK-NEXT:}
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond1
 func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = constant true

@@ -729,6 +731,8 @@ func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg0 : i32, i1
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond2
 func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
   %true = constant true

@@ -753,6 +757,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
 
+// -----
 
 // CHECK-LABEL: @replace_if_with_cond3
 func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {

@@ -774,6 +779,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg1 : i32, i64
 
+// -----
 
 // CHECK-LABEL: @while_cond_true
 func @while_cond_true() {

@@ -366,10 +366,11 @@ func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
 }
 // CHECK-LABEL: func @insert_slice_canonicalize
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-// CHEKC: return %[[RESULT]]
+// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
 
 // -----
 

@@ -517,3 +518,17 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
+// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src(
+    %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+      : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+  return %r : tensor<?x?x?xf32>
+}