[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.

Differential Revision: https://reviews.llvm.org/D95305
This commit is contained in:
MaheshRavishankar 2021-01-26 23:21:33 -08:00
parent 48bdd676a1
commit 7c15e0f64c
2 changed files with 42 additions and 3 deletions

View File

@@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
}
namespace {
struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
/// Since `init_tensor` operation creates a tensor needed only for its shape, a
/// subtensor of this is also needed only for its shape. The result can be
/// replaced by a new init_tensor operation of the same size as the subtensor
/// op.
/// An `init_tensor` op materializes a tensor that is used only for its shape,
/// so a `subtensor` carved out of it is likewise shape-only. Fold the pair
/// into a fresh `init_tensor` op sized like the subtensor result.
struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SubTensorOp subtensorOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the subtensor source comes directly from init_tensor.
    auto initTensorOp =
        subtensorOp.source().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp)
      return failure();
    // Collect the static result sizes of the subtensor as integers.
    SmallVector<int64_t, 4> staticSizes;
    for (Attribute sizeAttr : subtensorOp.static_sizes())
      staticSizes.push_back(sizeAttr.cast<IntegerAttr>().getInt());
    // Replace with an init_tensor of the subtensor's (mixed static/dynamic)
    // shape; element type is inherited from the source tensor.
    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
        subtensorOp, subtensorOp.sizes(), staticSizes,
        subtensorOp.getSourceType().getElementType());
    return success();
  }
};
struct FoldInitTensorWithTensorReshapeOp
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
/// Registers the canonicalization patterns for `linalg.init_tensor`.
/// NOTE(review): the diff view had flattened both the pre-change registration
/// (`FoldWithTensorReshapeOp`, since renamed) and the post-change one into the
/// text; only the updated registration is kept — the old statement would
/// double-register patterns and reference a name this change removes.
void InitTensorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results
      .insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
              ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
}
//===----------------------------------------------------------------------===//

View File

@@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
// CHECK-LABEL: func @keep_not_noop
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
// -----
// Exercises the FoldInitTensorWithSubTensorOp canonicalization: a subtensor
// of a linalg.init_tensor should fold to a new linalg.init_tensor whose
// sizes are the subtensor's sizes ([5, %arg1, 20]); the original init_tensor
// and the subtensor op both disappear (see CHECK lines below).
func @fold_init_tensor_with_subtensor
(%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
{
%0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32>
%1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
: tensor<?x10x40xf32> to tensor<5x?x20xf32>
return %1 : tensor<5x?x20xf32>
}
// CHECK: func @fold_init_tensor_with_subtensor
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
// CHECK: return %[[T0]]