From 7d6ef5caef80a24d170dee0f1fec54f3bc7fd979 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Thu, 28 Jul 2022 13:11:05 +0530 Subject: [PATCH] [mlir][tensor] Fold `tensor.cast` into `tensor.collapse_shape` op This commit folds a `tensor.cast` op into a `tensor.collapse_shape` op when following two conditions meet: 1. the `tensor.collapse_shape` op consumes result of the `tensor.cast` op. 2. `tensor.cast` op casts to a more dynamic version of the source tensor. This is added as a canonicalization pattern in `tensor.collapse_shape` op. Signed-Off-By: Gaurav Shukla Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D130650 --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 40 +++++++++++++++++++--- mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index a9437634b285..2d91f45205e6 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -928,6 +928,36 @@ struct FoldReshapeWithFromElements : OpRewritePattern { } }; +// Fold CastOp into CollapseShapeOp when adding static information. +struct FoldCollapseOfCastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp, + PatternRewriter &rewriter) const override { + auto castOp = collapseShapeOp.getSrc().getDefiningOp(); + if (!tensor::canFoldIntoConsumerOp(castOp)) + return failure(); + + RankedTensorType srcType = + castOp.getSource().getType().cast(); + RankedTensorType newResultType = computeTensorReshapeCollapsedType( + srcType, collapseShapeOp.getReassociationMaps()); + + if (newResultType == collapseShapeOp.getResultType()) { + rewriter.updateRootInPlace(collapseShapeOp, [&]() { + collapseShapeOp.getSrcMutable().assign(castOp.getSource()); + }); + } else { + auto newOp = rewriter.create( + collapseShapeOp.getLoc(), newResultType, castOp.getSource(), + collapseShapeOp.getReassociation()); + rewriter.replaceOpWithNewOp( + collapseShapeOp, collapseShapeOp.getResultType(), newOp); + } + return success(); + } +}; + } // namespace void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -940,10 +970,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeCollapseOfExpandOp, - FoldReshapeWithConstant, - FoldReshapeWithFromElements>(context); + results + .add, + ComposeCollapseOfExpandOp, + FoldReshapeWithConstant, + FoldReshapeWithFromElements, FoldCollapseOfCastOp>( + context); } OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index d07f3e894e24..1eb1a5d7beca 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -673,6 +673,20 @@ func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor) // ----- +// CHECK-LABEL: func.func @collapse_of_cast( +// CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor { +// CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32> +// CHECK-NEXT %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor +// CHECK-NEXT return %[[CAST]] : tensor +func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor { + %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor + %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor into tensor + %2 = tensor.cast %1 : tensor to tensor + return %2 : tensor +} + +// ----- + func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> { %0 = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<12x4xf32> into tensor<3x4x4xf32>