diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a6f3576c4240..2982132aa7d1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder, } namespace { -struct FoldWithTensorReshapeOp : public OpRewritePattern { +/// Since `init_tensor` operation creates a tensor needed only for its shape, a +/// subtensor of this is also needed only for its shape. The result can be +/// replaced by a new init_tensor operation of the same size as the subtensor +/// op. +struct FoldInitTensorWithSubTensorOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SubTensorOp subtensorOp, + PatternRewriter &rewriter) const override { + if (!subtensorOp.source().getDefiningOp()) + return failure(); + rewriter.replaceOpWithNewOp( + subtensorOp, subtensorOp.sizes(), + llvm::to_vector<4>(llvm::map_range( + subtensorOp.static_sizes(), + [](Attribute attr) { return attr.cast().getInt(); })), + subtensorOp.getSourceType().getElementType()); + return success(); + } +}; + +struct FoldInitTensorWithTensorReshapeOp + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, @@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern { void InitTensorOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results + .insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index cc00b98d376c..418d9d2195b9 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor, %arg1 : tensor) // CHECK-LABEL: func @keep_not_noop // CHECK: %[[RESULT:.+]]:2 = linalg.generic // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @fold_init_tensor_with_subtensor + (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> +{ + %0 = linalg.init_tensor[%arg0, 10, 40] : tensor + %1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1] + : tensor to tensor<5x?x20xf32> + return %1 : tensor<5x?x20xf32> +} +// CHECK: func @fold_init_tensor_with_subtensor +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20] +// CHECK: return %[[T0]]