[MLIR][SCF] Inline single block ExecuteRegionOp

This commit adds a canonicalization pass which inlines any single block execute region

Differential Revision: https://reviews.llvm.org/D104865
This commit is contained in:
William S. Moses 2021-06-24 12:33:54 -04:00
parent 3ba090e5f6
commit 44985872b8
3 changed files with 69 additions and 17 deletions

View File

@ -111,7 +111,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
// TODO: If the parent is a func like op (which would be the case if all other
// ops are from the std dialect), the inliner logic could be readily used to
// inline.
let hasCanonicalizer = 0;
let hasCanonicalizer = 1;
// TODO: can fold if it returns a constant.
// TODO: Single block execute_region ops can be readily inlined irrespective

View File

@ -73,6 +73,19 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Region &region, ValueRange blockArgs = {}) {
assert(llvm::hasSingleElement(region) && "expected single-region block");
Block *block = &region.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.mergeBlockBefore(block, op, blockArgs);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}
///
/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
/// block+
@ -118,6 +131,37 @@ static LogicalResult verify(ExecuteRegionOp op) {
return success();
}
// Inline an ExecuteRegionOp if it only contains one block.
// "test.foo"() : () -> ()
// %v = scf.execute_region -> i64 {
// %x = "test.val"() : () -> i64
// scf.yield %x : i64
// }
// "test.bar"(%v) : (i64) -> ()
//
// becomes
//
// "test.foo"() : () -> ()
// %x = "test.val"() : () -> i64
// "test.bar"(%v) : (i64) -> ()
//
struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
if (op.region().getBlocks().size() != 1)
return failure();
replaceOpWithRegion(rewriter, op, op.region());
return success();
}
};
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SingleBlockExecuteInliner>(context);
}
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@ -444,19 +488,6 @@ LoopNest mlir::scf::buildLoopNest(
});
}
/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Region &region, ValueRange blockArgs = {}) {
assert(llvm::hasSingleElement(region) && "expected single-region block");
Block *block = &region.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.mergeBlockBefore(block, op, blockArgs);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.

View File

@ -921,9 +921,30 @@ func @propagate_into_execute_region() {
}
"test.bar"(%v) : (i64) -> ()
// CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: scf.execute_region -> i64 {
// CHECK-NEXT: scf.yield %[[C2]] : i64
// CHECK-NEXT: }
// CHECK: "test.foo"
// CHECK-NEXT: "test.bar"(%[[C2]]) : (i64) -> ()
}
return
}
// -----
// CHECK-LABEL: func @execute_region_elim
func @execute_region_elim() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%x = "test.val"() : () -> i64
scf.yield %x : i64
}
"test.bar"(%v) : (i64) -> ()
}
return
}
// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
// CHECK-NEXT: "test.foo"() : () -> ()
// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
// CHECK-NEXT: }