forked from OSchip/llvm-project
[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
Differential Revision: https://reviews.llvm.org/D95305
This commit is contained in:
parent
48bdd676a1
commit
7c15e0f64c
|
@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
|
|||
}
|
||||
|
||||
namespace {
|
||||
struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
|
||||
/// Since `init_tensor` operation creates a tensor needed only for its shape, a
|
||||
/// subtensor of this is also needed only for its shape. The result can be
|
||||
/// replaced by a new init_tensor operation of the same size as the subtensor
|
||||
/// op.
|
||||
struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
|
||||
using OpRewritePattern<SubTensorOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SubTensorOp subtensorOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!subtensorOp.source().getDefiningOp<linalg::InitTensorOp>())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
|
||||
subtensorOp, subtensorOp.sizes(),
|
||||
llvm::to_vector<4>(llvm::map_range(
|
||||
subtensorOp.static_sizes(),
|
||||
[](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
|
||||
subtensorOp.getSourceType().getElementType());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldInitTensorWithTensorReshapeOp
|
||||
: public OpRewritePattern<TensorReshapeOp> {
|
||||
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
|
@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
|
|||
|
||||
void InitTensorOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
|
||||
ReplaceStaticShapeDims>(context);
|
||||
results
|
||||
.insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
|
||||
ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
|
|||
// CHECK-LABEL: func @keep_not_noop
|
||||
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
|
||||
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
|
||||
|
||||
// -----
|
||||
|
||||
func @fold_init_tensor_with_subtensor
|
||||
(%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
|
||||
{
|
||||
%0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32>
|
||||
%1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
|
||||
: tensor<?x10x40xf32> to tensor<5x?x20xf32>
|
||||
return %1 : tensor<5x?x20xf32>
|
||||
}
|
||||
// CHECK: func @fold_init_tensor_with_subtensor
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
|
||||
// CHECK: return %[[T0]]
|
||||
|
|
Loading…
Reference in New Issue