From 44985872b8a058272741703ee63bf7ca34212449 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Thu, 24 Jun 2021 12:33:54 -0400 Subject: [PATCH] [MLIR][SCF] Inline single block ExecuteRegionOp This commit adds a canonicalization pass which inlines any single block execute region Differential Revision: https://reviews.llvm.org/D104865 --- mlir/include/mlir/Dialect/SCF/SCFOps.td | 2 +- mlir/lib/Dialect/SCF/SCF.cpp | 57 +++++++++++++++++++------ mlir/test/Dialect/SCF/canonicalize.mlir | 27 ++++++++++-- 3 files changed, 69 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td index a5584392aa61..c10441f59bd5 100644 --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -111,7 +111,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> { // TODO: If the parent is a func like op (which would be the case if all other // ops are from the std dialect), the inliner logic could be readily used to // inline. - let hasCanonicalizer = 0; + let hasCanonicalizer = 1; // TODO: can fold if it returns a constant. // TODO: Single block execute_region ops can be readily inlined irrespective diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index 0d22e17b245f..99d2386ced1b 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -73,6 +73,19 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { // ExecuteRegionOp //===----------------------------------------------------------------------===// +/// Replaces the given op with the contents of the given single-block region, +/// using the operands of the block terminator to replace operation results. +static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, + Region ®ion, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-region block"); + Block *block = ®ion.front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + /// /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` /// block+ @@ -118,6 +131,37 @@ static LogicalResult verify(ExecuteRegionOp op) { return success(); } +// Inline an ExecuteRegionOp if it only contains one block. +// "test.foo"() : () -> () +// %v = scf.execute_region -> i64 { +// %x = "test.val"() : () -> i64 +// scf.yield %x : i64 +// } +// "test.bar"(%v) : (i64) -> () +// +// becomes +// +// "test.foo"() : () -> () +// %x = "test.val"() : () -> i64 +// "test.bar"(%v) : (i64) -> () +// +struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { + using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.region().getBlocks().size() != 1) + return failure(); + replaceOpWithRegion(rewriter, op, op.region()); + return success(); + } +}; + +void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<SingleBlockExecuteInliner>(context); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// @@ -444,19 +488,6 @@ LoopNest mlir::scf::buildLoopNest( }); } -/// Replaces the given op with the contents of the given single-block region, -/// using the operands of the block terminator to replace operation results. -static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, - Region ®ion, ValueRange blockArgs = {}) { - assert(llvm::hasSingleElement(region) && "expected single-region block"); - Block *block = ®ion.front(); - Operation *terminator = block->getTerminator(); - ValueRange results = terminator->getOperands(); - rewriter.mergeBlockBefore(block, op, blockArgs); - rewriter.replaceOp(op, results); - rewriter.eraseOp(terminator); -} - namespace { // Fold away ForOp iter arguments when: // 1) The op yields the iter arguments. diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 8b57bc8513cd..8692f2d9705e 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -921,9 +921,30 @@ func @propagate_into_execute_region() { } "test.bar"(%v) : (i64) -> () // CHECK: %[[C2:.*]] = constant 2 : i64 - // CHECK: scf.execute_region -> i64 { - // CHECK-NEXT: scf.yield %[[C2]] : i64 - // CHECK-NEXT: } + // CHECK: "test.foo" + // CHECK-NEXT: "test.bar"(%[[C2]]) : (i64) -> () } return } + +// ----- + +// CHECK-LABEL: func @execute_region_elim +func @execute_region_elim() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + "test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> () +// CHECK-NEXT: } +